Skip to content

Commit 6da35a7

Browse files
authored
[KERNELS] fixes a couple of issues in matmul_ogs.py (#7177)
1 parent e3d0ec9 commit 6da35a7

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,7 @@ def apply_allocation(allocation: MatmulAllocation, output):
584584

585585
def _create_tma_descriptors(
586586
x: torch.Tensor,
587-
x_tensor: torch.Tensor,
588-
w_tensor: torch.Tensor,
587+
w: torch.Tensor,
589588
mx_tensor: Optional[torch.Tensor],
590589
routing_data: RoutingData,
591590
mx_ctx: MicroscalingCtx,
@@ -611,21 +610,21 @@ def _create_tma_descriptors(
611610
if (use_host_tma_descriptors):
612611
if USE_GATHER_TMA or X_USE_LOAD_TMA:
613612
x_desc = TensorDescriptorBuilder.create_input_descriptor(
614-
x_tensor, K, x.stride(1), x.stride(2),
613+
x, K, x.stride(1), x.stride(2),
615614
opt_flags.block_k, opt_flags.block_m,
616615
USE_GATHER_TMA, X_USE_LOAD_TMA
617616
)
618617
descriptors.append(x_desc)
619618
if (expt_data is not None and len(expt_data.block_pid_map) > 0):
620619
w_desc = TensorDescriptorBuilder.create_weight_descriptor(
621-
w_tensor, opt_flags.block_k, opt_flags.block_n, w_transpose
620+
w, opt_flags.block_k, opt_flags.block_n, w_transpose
622621
)
623-
is_microscaled_format = (mx_ctx.weight_scale is not None) and (w_tensor.dtype == torch.uint8)
622+
is_microscaled_format = (mx_ctx.weight_scale is not None) and (w.dtype == torch.uint8)
624623
if is_microscaled_format:
625624
# Pad the inner shape to 128 for mxfp4 weights
626625
# for mixed precision fp8 x mxfp4 compute
627626
pad = 128
628-
dim_to_pad = -1 if w_transpose else -2
627+
dim_to_pad = -1
629628
old_size = w_desc.shape[dim_to_pad]
630629
padded_size = math.ceil(old_size / pad) * pad
631630
if padded_size != old_size:
@@ -645,7 +644,7 @@ def _create_tma_descriptors(
645644
# TODO: Currently all or none, instead should support a mixture
646645
# of host and device descriptors
647646
if None in descriptors or len(descriptors) == 0:
648-
descriptors = [x_tensor, w_tensor, mx_tensor]
647+
descriptors = [x, w, mx_tensor]
649648
use_host_tma_descriptors = False
650649
if opt_flags.is_persistent:
651650
opt_flags.target_kernel_kwargs["USE_HOST_TMA_DESCRIPTORS"] = use_host_tma_descriptors
@@ -759,9 +758,7 @@ def matmul_ogs(x, w, bias,
759758
USE_GATHER_TMA = HAS_TMA_GS and gather_indx is not None
760759
X_USE_LOAD_TMA = gather_indx is None and not USE_GATHER_TMA
761760
_, x_tensor, w_tensor, mx_tensor = _create_tma_descriptors(
762-
x=x,
763-
x_tensor=flex.lhs_data.reinterpret(x),
764-
w_tensor=flex.rhs_data.reinterpret(w),
761+
x=x, w=w,
765762
mx_tensor=mx_ctx.weight_scale,
766763
routing_data=routing_data,
767764
mx_ctx=mx_ctx,
@@ -777,7 +774,10 @@ def matmul_ogs(x, w, bias,
777774
w_transpose=w.stride(2) != 1,
778775
mx_transpose=mx_scale_stride_n != 1,
779776
)
780-
777+
if isinstance(x_tensor, torch.Tensor):
778+
x_tensor = flex.lhs_data.reinterpret(x)
779+
if isinstance(w_tensor, torch.Tensor):
780+
w_tensor = flex.rhs_data.reinterpret(w)
781781
(kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
782782
flex.out_data.reinterpret(memory["output"]),
783783
flex.out_data.reinterpret(out0), *out0.stride(),

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,14 @@ def matmul_repr(specialization):
5252

5353
def convert_dtype(dtype):
5454
if "tensordesc" in dtype:
55-
return dtype.split("<")[1].split("[")[0]
55+
ret = convert_dtype(dtype.split("<")[1].split("[")[0])
56+
return ret
5657
elif "u8" in dtype:
5758
return "mxfp4"
58-
else:
59+
elif dtype[0] == "*":
5960
return dtype[1:]
61+
else:
62+
return dtype
6063

6164
dtypes = "x".join([convert_dtype(f"{signature[i]}") for i in reorder(["Y", "X", "W"])])
6265
layouts = "".join([f"{layout(i)}" for i in reorder(["stride_y_n", "stride_x_k", "stride_w_n"])])

python/triton_kernels/triton_kernels/matmul_ogs_details/_finalize_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def _finalize_matmul(
291291
if src_idx != -1:
292292
As = A + src_idx.to(tl.int64) * stride_a_m + offs_n
293293
for ki in tl.static_range(K):
294-
acc += tl.load(As, mask=(src_idxs != -1)[:, None] & n_mask[None, :], other=0.0)
294+
acc += tl.load(As, mask=n_mask, other=0.0)
295295
As += stride_a_k
296296
else:
297297
As = A + src_idxs.to(tl.int64)[:, None] * stride_a_m + offs_n[None, :]

0 commit comments

Comments (0)