Skip to content

Commit d8c0a87

Browse files
committed
fix
Signed-off-by: HuiyingLi <[email protected]>
1 parent 5f27227 commit d8c0a87

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/unit_tests/models/gpt_oss/test_gptoss_rope_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,9 @@ def test_dtype_preserved_with_bfloat16_input(self):
216216

217217
# Create bfloat16 input
218218
x_bf16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16)
219-
cos = torch.randn(seq_len, head_dim // 2)
220-
sin = torch.randn(seq_len, head_dim // 2)
219+
angles = torch.randn(seq_len, head_dim // 2)
220+
cos = angles.cos()
221+
sin = angles.sin()
221222

222223
# Apply rotary embedding
223224
result = apply_rotary_emb(x_bf16, cos, sin)

0 commit comments

Comments
 (0)