Skip to content

Commit bf5bd5f

Browse files
authored
fix mx kernel tests (#2614)
Update [ghstack-poisoned]
1 parent ebfe173 commit bf5bd5f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchao/prototype/mx_formats/kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,7 @@ def triton_to_mxfp8_dim1_reference(
13941394
scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
13951395
return (
13961396
x_hp_d1_normalized.t(),
1397-
scale_e8m0_dim1,
1397+
scale_e8m0_dim1.unsqueeze(-1),
13981398
)
13991399

14001400
@triton.jit

0 commit comments

Comments
 (0)