Commit bb1000f

[kernels] remove want_n_major condition in persistent matmul (#7167)
The fixes in triton-lang/triton#7119 handle this condition without requiring a transposition: they pad the shape of the TMA descriptor instead.
1 parent 6d5fb9f · commit bb1000f
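
For context, the fallback that this commit deletes existed because fp4 padded mode needs the contiguous dimension of the weight tensor to cover a whole number of 64-byte chunks (see the removed comment in the diff below). What follows is a minimal sketch of the padding arithmetic that makes the N-major copy unnecessary; padded_contiguous_dim is a hypothetical helper written for illustration, not code taken from this repository or from triton-lang/triton#7119.

def padded_contiguous_dim(size_elems: int, elem_bytes: int = 1, alignment_bytes: int = 64) -> int:
    # Round the contiguous dimension up to a whole number of 64-byte chunks.
    # For uint8-packed fp4 weights one element is one byte, so elements == bytes.
    elems_per_chunk = alignment_bytes // elem_bytes
    return (size_elems + elems_per_chunk - 1) // elems_per_chunk * elems_per_chunk

# A K-major fp4 weight whose contiguous dim has 100 elements fails the old
# `w.shape[1] % 64 != 0` check; describing it to the TMA engine with a padded
# size of 128 sidesteps the N-major transposition the old code fell back to.
assert padded_contiguous_dim(100) == 128
assert padded_contiguous_dim(128) == 128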

File tree

1 file changed (+2, -15 lines)

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 2 additions & 15 deletions
@@ -367,18 +367,13 @@ def can_use_fused_scatter(scatter_indx, fused_activation):
 
 @dataclass(frozen=True)
 class PreprocessingFeatures:
-    w_want_n_major: bool
     w_want_k_major: bool
     swap_xw: bool
 
-    def __post_init__(self):
-        assert not (self.w_want_k_major and self.w_want_n_major), "Cannot have both K-major and N-major"
-
 def init_preprocessing_features(w, precision_config, opt_flags):
     mx_ctx = precision_config.mx_ctx
     swap_xw = False  # Whether or not to swap X and W operands to the tl.dot
     w_want_k_major = False
-    w_want_n_major = False
     if not target_info.cuda_capability_geq(10, 0):
         # Hopper transpose. Reduction dimension must be contiguous.
         if w.stride(1) != 1 and w.dtype.itemsize == 1:
@@ -388,12 +383,7 @@ def init_preprocessing_features(w, precision_config, opt_flags):
         swap_xw = mx_ctx.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent
     if swap_xw:
         w_want_k_major = True
-    # fp4 padded mode requires the contiguous dim size to be a multiple of 64 bytes. If it is K-major and does not
-    # meet the requirement, make the tensor N-major instead.
-    # But, don't do this if we're going to swap X and W in which case we would transpose W again.
-    if w.stride(1) == 1 and w.dtype == torch.uint8 and w.shape[1] % 64 != 0 and not swap_xw:
-        w_want_n_major = True
-    return PreprocessingFeatures(w_want_n_major, w_want_k_major, swap_xw)
+    return PreprocessingFeatures(w_want_k_major, swap_xw)
 
 
 def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features):
@@ -420,10 +410,7 @@ def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data,
         finalize_scatter_idxs = None
     else:
         writeback_idxs, writeback_size, finalize_scatter_idxs = None, None, None
-    # some transposition variants aren't supported
-    if preprocessing_features.w_want_n_major:
-        w = fast_contiguous(w)
-    elif preprocessing_features.w_want_k_major:
+    if preprocessing_features.w_want_k_major:
         w = fast_contiguous(w.transpose(-1, -2)).transpose(-1, -2)
     # preprocess routing information and ptr lookup table
     M = x.shape[1] if gather_indx is None else gather_indx.src_indx.shape[0]
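
The branch that survives makes W K-major with a transpose/contiguous/transpose round trip. The snippet below is a standalone sketch of what that pattern does to the tensor's layout, assuming fast_contiguous behaves like torch.Tensor.contiguous for this purpose (the shapes are made up for illustration):

import torch

# A row-major weight of shape (K, N): the reduction (K) dimension is strided,
# not contiguous, which is what w_want_k_major is meant to fix.
K, N = 128, 64
w = torch.zeros(K, N)
assert w.stride(-2) != 1

# Transpose, materialize contiguously, transpose back: the logical shape is
# unchanged, but K now has stride 1, i.e. the tensor is K-major.
w_k_major = w.transpose(-1, -2).contiguous().transpose(-1, -2)
assert w_k_major.shape == w.shape
assert w_k_major.stride(-2) == 1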
