Skip to content

Commit e17dd99

Browse files
committed
pass dtype policy
1 parent 5a6fb27 commit e17dd99

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
self.rotary_embedding = RotaryEmbedding(
113113
max_wavelength=rope_theta,
114114
scaling_factor=rope_scaling,
115-
dtype=self.dtype_policy
115+
dtype=self.token_embedding.dtype_policy
116116
)
117117

118118
# === Functional Model ===
@@ -161,6 +161,7 @@ def __init__(
161161
self.max_position_embeddings = max_position_embeddings
162162
self.rope_theta = rope_theta
163163
self.partial_rotary_factor = partial_rotary_factor
164+
self.rope_scaling = rope_scaling
164165

165166
def get_config(self):
166167
config = super().get_config()
@@ -181,6 +182,7 @@ def get_config(self):
181182
"max_position_embeddings": self.max_position_embeddings,
182183
"rope_theta": self.rope_theta,
183184
"partial_rotary_factor": self.partial_rotary_factor,
185+
"rope_scaling": self.rope_scaling
184186
}
185187
)
186188
return config

0 commit comments

Comments
 (0)