Skip to content

Commit 7708b9c

Browse files
[mxfp/easy] handle an empty tensor in downcast_to_mxfp (#7579)
Fix ``` > kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1]) E RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0] because the unspecified dimension size -1 can be any value and is ambiguous ``` <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 9660a0d commit 7708b9c

File tree

1 file changed

+14
-13
lines changed
  • python/triton_kernels/triton_kernels/numerics_details

1 file changed

+14
-13
lines changed

python/triton_kernels/triton_kernels/numerics_details/mxfp.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,20 @@ def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis
4646
out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
4747
out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)
4848

49-
kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
50-
kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
51-
kernel_scale = out_scale.view(-1, out_scale.shape[-1])
52-
53-
BLOCK_OUT_DIM = 128
54-
BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE
55-
grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
56-
grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)
57-
58-
_downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
59-
*kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
60-
*kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
61-
DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)
49+
if src_tensor.numel() > 0:
50+
kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
51+
kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
52+
kernel_scale = out_scale.view(-1, out_scale.shape[-1])
53+
54+
BLOCK_OUT_DIM = 128
55+
BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE
56+
grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
57+
grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)
58+
59+
_downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
60+
*kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
61+
*kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
62+
DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)
6263

6364
out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
6465
out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)

0 commit comments

Comments
 (0)