Skip to content

Commit a663a5c

Browse files
committed
refactor rotary embeddings
1 parent f913179 commit a663a5c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def call(
606606
inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1)
607607
)
608608

609-
position_ids_expanded = ops.expand_dims(positions, axis=1)
609+
position_ids_expanded = ops.expand_dims(positions, axis=1).T
610610

611611
freqs = ops.matmul(
612612
ops.cast(inv_freq_expanded, "float32"),

0 commit comments

Comments
 (0)