Skip to content

Commit bb3d157

Browse files
Hilly12recml authors
authored andcommitted
Add a clone initializer utility.
PiperOrigin-RevId: 764467300
1 parent 2b54892 commit bb3d157

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

recml/layers/keras/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919
Tensor = Any
2020

2121

22+
def clone_initializer(initializer: Any) -> Any:
23+
"""Clones an initializer."""
24+
if isinstance(initializer, keras.initializers.Initializer):
25+
return initializer.clone()
26+
return initializer
27+
28+
2229
def make_attention_mask(mask: Tensor, dtype: str = "float32") -> Tensor:
2330
"""Creates a 3D self-attention mask from a padding mask."""
2431
# Element wise pairwise function on [B, L, 1], [B, 1, L].

recml/layers/keras/utils_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@
2323

2424
class UtilsTest(testing.TestCase):
2525

26+
def test_clone_initializer(self):
27+
random_initializer = keras.initializers.RandomNormal(stddev=1.0)
28+
random_initializer_clone = utils.clone_initializer(random_initializer)
29+
self.assertNotEqual(random_initializer.seed, random_initializer_clone.seed)
30+
31+
lecun_initializer = keras.initializers.LecunNormal(seed=1)
32+
lecun_initializer_clone = utils.clone_initializer(lecun_initializer)
33+
self.assertEqual(lecun_initializer.seed, lecun_initializer_clone.seed)
34+
35+
self.assertEqual(utils.clone_initializer("lecun_normal"), "lecun_normal")
36+
2637
# Remember to read these sideways =))
2738
@parameterized.parameters(
2839
dict(

0 commit comments

Comments
 (0)