2 files changed: +11 -1 lines changed
torchao/prototype/mx_formats: 2 files changed, +11 -1 lines changed
Original file line number | Diff line number | Diff line change
@@ -1448,6 +1448,7 @@ def triton_scale_swizzle(
1448
1448
scales_flat ,
1449
1449
)
1450
1450
1451
+ @torch .library .custom_op ("torchao::triton_mx_block_rearrange" , mutates_args = ())
1451
1452
def triton_mx_block_rearrange (scale_tensor : torch .Tensor ) -> torch .Tensor :
1452
1453
"""
1453
1454
Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format.
@@ -1716,6 +1717,15 @@ def _(x, per_tensor_scale=None):
1716
1717
xq = torch .empty (M , N // 2 , device = x .device , dtype = torch .uint8 )
1717
1718
return scales , xq
1718
1719
1720
@triton_mx_block_rearrange.register_fake
def _(scale_tensor):
    """Fake kernel registered for the ``torchao::triton_mx_block_rearrange`` op.

    Only allocates an uninitialised result with ``new_empty``; no data is
    read or rearranged here.

    Args:
        scale_tensor: A tensor whose ``shape`` unpacks into exactly two
            dimensions (rows, cols).

    Returns:
        A new empty tensor created from ``scale_tensor`` whose row count is
        rounded up to a multiple of 128 and whose column count is rounded up
        to a multiple of 4.
    """
    num_rows, num_cols = scale_tensor.shape
    # Round each dimension up to a whole number of 128-row / 4-column blocks.
    # NOTE(review): presumably mirrors the padding of the real kernel — confirm.
    out_rows = triton.cdiv(num_rows, 128) * 128
    out_cols = triton.cdiv(num_cols, 4) * 4
    return scale_tensor.new_empty((out_rows, out_cols))
1719
1729
else :
1720
1730
1721
1731
def triton_to_mxfp8_dim1 (
Original file line number Diff line number Diff line change @@ -15,7 +15,7 @@ def ceil_div(a, b):
15
15
return (a + b - 1 ) // b
16
16
17
17
18
- def to_blocked (input_matrix , use_triton_kernel : bool = True ) -> Tensor :
18
+ def to_blocked (input_matrix , use_triton_kernel : bool = False ) -> Tensor :
19
19
"""
20
20
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
21
21
You can’t perform that action at this time.
0 commit comments