Skip to content

Commit 9f12f14

Browse files
authored
fix swizzle kernel w/ fullgraph (#2705)
stack-info: PR: #2705, branch: drisspg/stack/86
1 parent 785f3dd commit 9f12f14

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

torchao/prototype/mx_formats/kernels.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1448,6 +1448,7 @@ def triton_scale_swizzle(
14481448
scales_flat,
14491449
)
14501450

1451+
@torch.library.custom_op("torchao::triton_mx_block_rearrange", mutates_args=())
14511452
def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
14521453
"""
14531454
Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format.
@@ -1716,6 +1717,15 @@ def _(x, per_tensor_scale=None):
17161717
xq = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8)
17171718
return scales, xq
17181719

1720+
@triton_mx_block_rearrange.register_fake
def _(scale_tensor):
    """Fake implementation registered for ``triton_mx_block_rearrange``.

    Computes only the output shape, without rearranging any data.

    Args:
        scale_tensor: A 2-D tensor; its shape is unpacked as (rows, cols).

    Returns:
        An uninitialized tensor created via ``scale_tensor.new_empty`` whose
        row count is rounded up to a multiple of 128 and whose column count
        is rounded up to a multiple of 4.
    """
    num_rows, num_cols = scale_tensor.shape

    # Round each dimension up to a whole number of 128-row / 4-column blocks.
    # NOTE(review): presumably this mirrors the output shape of the real
    # kernel, whose body is not visible here -- confirm they stay in sync.
    padded_shape = (
        triton.cdiv(num_rows, 128) * 128,
        triton.cdiv(num_cols, 4) * 4,
    )
    return scale_tensor.new_empty(padded_shape)
17191729
else:
17201730

17211731
def triton_to_mxfp8_dim1(

torchao/prototype/mx_formats/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ def ceil_div(a, b):
1515
return (a + b - 1) // b
1616

1717

18-
def to_blocked(input_matrix, use_triton_kernel: bool = True) -> Tensor:
18+
def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
1919
"""
2020
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
2121

0 commit comments

Comments
 (0)