Skip to content

Commit 88a2851

Browse files
authored
[kernels] fix tma construction bug when M == 1 (#7148)
1 parent 7be5b8a commit 88a2851

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,17 @@ def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n
144144
block_shape=[1, MX_SCALE_BLOCK_K,
145145
block_n], transpose=transpose)
146146

147+
@staticmethod
148+
def squeeze_after_dim(x, dim=2):
149+
shape = list(x.shape)
150+
new_shape = [s for s in shape[:dim - 1] if s != 1] + shape[dim - 1:]
151+
return x.view(*new_shape)
152+
147153
@staticmethod
148154
def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int,
149155
block_k: int) -> TensorDescriptor:
150156
"""Create a tensor descriptor for input matrix X via TMA gather"""
151-
x_desc = x_tensor.squeeze()
157+
x_desc = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
152158
assert x_desc.ndim == 2, "TMA gather descriptor requires 2D input"
153159
INT_MAX = 2147483647
154160
return TensorDescriptor(base=x_desc, shape=[INT_MAX, K], strides=[x_stride_1, x_stride_2],
@@ -158,7 +164,7 @@ def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: i
158164
def create_input_descriptor_load(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_m: int,
159165
block_k: int) -> TensorDescriptor:
160166
"""Create a tensor descriptor for input matrix X via TMA"""
161-
x_desc = x_tensor.squeeze()
167+
x_desc = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
162168
assert x_desc.ndim in [2, 3], "LHS input TMA descriptor builder expects 2D or 3D input"
163169
return TensorDescriptor(base=x_desc, shape=[x_desc.shape[0], K], strides=[x_stride_1, x_stride_2],
164170
block_shape=[block_m, block_k])

0 commit comments

Comments
 (0)