Commit e2f194a

committed
make training false
1 parent 3f4eb53 commit e2f194a

File tree

keras_hub/src/models/smollm3/smollm3_backbone.py
keras_hub/src/models/smollm3/smollm3_layers.py
keras_hub/src/models/smollm3/smollm3_utils.py
keras_hub/src/utils/transformers/convert_smollm3.py

4 files changed: +12 −13 lines changed

keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 1 addition & 1 deletion
@@ -176,4 +176,4 @@ def get_config(self):
                 "partial_rotary_factor": self.partial_rotary_factor,
             }
         )
-        return config
+        return config
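For reference, the hunk above touches the standard Keras serialization pattern: extend the parent config, then return it. A minimal runnable sketch of that pattern (ToyLayer and its single field are hypothetical, not from this commit):

from keras import layers


class ToyLayer(layers.Layer):
    def __init__(self, partial_rotary_factor=0.5, **kwargs):
        super().__init__(**kwargs)
        self.partial_rotary_factor = partial_rotary_factor

    def get_config(self):
        # Extend the parent config instead of replacing it, so the
        # layer round-trips through Keras saving/loading.
        config = super().get_config()
        config.update(
            {"partial_rotary_factor": self.partial_rotary_factor}
        )
        return config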

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 8 additions & 10 deletions
@@ -121,9 +121,9 @@ def call(
         input_shape = ops.shape(hidden_states)[:-1]
         hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim)

-        query_states = ops.reshape(self.q_proj(hidden_states),hidden_shape)
+        query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape)
         # (batch, num_heads, seq_len, head_dim)
-        query_states = ops.transpose(query_states, axes=(0, 2, 1, 3))
+        query_states = ops.transpose(query_states, axes=(0, 2, 1, 3))

         def _compute_kv_values(x_input):
             kv_hidden_shape = (
@@ -132,13 +132,9 @@ def _compute_kv_values(x_input):
                 self.head_dim,
             )

-            key_states_raw = ops.reshape(
-                self.k_proj(x_input),
-                kv_hidden_shape
-            )
+            key_states_raw = ops.reshape(self.k_proj(x_input), kv_hidden_shape)
             value_states_raw = ops.reshape(
-                self.v_proj(x_input),
-                kv_hidden_shape
+                self.v_proj(x_input), kv_hidden_shape
             )

             key_states = ops.transpose(key_states_raw, axes=(0, 2, 1, 3))
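The two hunks above only reflow the projection code; the underlying shape dance is unchanged. A standalone sketch of that dance, with toy sizes standing in for the model's real num_attention_heads and head_dim, and the identity standing in for the q/k/v projections:

import numpy as np
from keras import ops

batch, seq_len, num_heads, head_dim = 2, 6, 4, 8
# Stand-in for self.q_proj(hidden_states): (batch, seq, heads * head_dim).
projected = ops.convert_to_tensor(
    np.random.rand(batch, seq_len, num_heads * head_dim).astype("float32")
)

# Split the last axis into heads, then move heads ahead of the
# sequence axis: (batch, num_heads, seq_len, head_dim).
states = ops.reshape(projected, (batch, seq_len, num_heads, head_dim))
states = ops.transpose(states, axes=(0, 2, 1, 3))
print(ops.shape(states))  # (2, 4, 6, 8)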
@@ -155,7 +151,9 @@ def _compute_kv_values(x_input):
             key_states = key_cache
             value_states = value_cache
         else:
-            print("self_attention_cache_update_index is not None, computing kv values")
+            print(
+                "self_attention_cache_update_index is not None, computing kv values"
+            )
             key_update, value_update = _compute_kv_values(hidden_states)
             update_idx_tensor = ops.convert_to_tensor(
                 self_attention_cache_update_index, dtype="int32"
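The else branch above computes fresh key/value tensors and writes them into the cache at self_attention_cache_update_index. A rough sketch of that kind of write, assuming a preallocated (batch, num_heads, max_seq_len, head_dim) cache and keras.ops.slice_update (the hunk does not show which update op the file actually uses):

from keras import ops

batch, num_heads, max_len, head_dim = 1, 2, 16, 4
key_cache = ops.zeros((batch, num_heads, max_len, head_dim))

# One new decode step worth of keys, written at sequence position 5.
key_update = ops.ones((batch, num_heads, 1, head_dim))
update_idx_tensor = ops.convert_to_tensor(5, dtype="int32")

# slice_update takes one start index per axis; only the sequence
# axis is offset here.
key_cache = ops.slice_update(
    key_cache, [0, 0, update_idx_tensor, 0], key_update
)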
@@ -417,7 +415,7 @@ def call(
             position_embeddings: Optional tuple of (cos, sin) tensors for RoPE.
             training: Whether the layer is in training mode.
         """
-        self_attention_cache = kwargs.get('self_attention_cache', None)
+        self_attention_cache = kwargs.get("self_attention_cache", None)

         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

keras_hub/src/models/smollm3/smollm3_utils.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def eager_attention_forward(
        attn_weights = ops.add(attn_weights, causal_mask)

    attn_weights = ops.softmax(attn_weights, axis=-1)
-
+
    if training:
        attn_weights = random.dropout(attn_weights, rate=dropout)
    attn_output = ops.matmul(attn_weights, value_states)
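Around the whitespace fix, the context lines show the textbook eager-attention pattern: additive mask, softmax, dropout only when training, then the value matmul. A self-contained sketch under those assumptions (the scaling detail is a guess, not shown in this hunk):

import numpy as np
from keras import ops, random


def toy_attention(query, key, value, causal_mask, dropout=0.0, training=False):
    # (batch, heads, q_len, k_len) scores, scaled by sqrt(head_dim).
    scale = 1.0 / np.sqrt(query.shape[-1])
    attn_weights = ops.matmul(
        query, ops.transpose(key, axes=(0, 1, 3, 2))
    ) * scale
    attn_weights = ops.add(attn_weights, causal_mask)
    attn_weights = ops.softmax(attn_weights, axis=-1)
    # Dropout hits the attention probabilities, and only in training,
    # which is exactly what the `if training:` gate protects.
    if training:
        attn_weights = random.dropout(attn_weights, rate=dropout)
    return ops.matmul(attn_weights, value)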

keras_hub/src/utils/transformers/convert_smollm3.py

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,6 @@ def transpose_and_reshape(x, shape):
     )

     # Attention layers
-
     ## Query
     loader.port_weight(
         keras_variable=decoder_layer.self_attn.q_proj.kernel,
@@ -110,6 +109,8 @@ def transpose_and_reshape(x, shape):
         hf_weight_key="model.norm.weight",
     )

+    backbone.training = False
+
     return backbone
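The new backbone.training = False line flags the converted model for inference before the converter returns it, so training-only branches (like the dropout gate in smollm3_utils.py) stay off by default. A hypothetical usage sketch, assuming keras-hub's hf:// preset loading routes through this converter and that the checkpoint id below exists (neither is stated in the diff):

import keras_hub

# Checkpoint id is an assumption, for illustration only.
backbone = keras_hub.models.Backbone.from_preset(
    "hf://HuggingFaceTB/SmolLM3-3B"
)

# Set by the converter after porting weights (this commit).
print(backbone.training)  # False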

0 commit comments