@@ -367,18 +367,13 @@ def can_use_fused_scatter(scatter_indx, fused_activation):
 
 @dataclass(frozen=True)
 class PreprocessingFeatures:
-    w_want_n_major: bool
     w_want_k_major: bool
     swap_xw: bool
 
-    def __post_init__(self):
-        assert not (self.w_want_k_major and self.w_want_n_major), "Cannot have both K-major and N-major"
-
 def init_preprocessing_features(w, precision_config, opt_flags):
     mx_ctx = precision_config.mx_ctx
     swap_xw = False  # Whether or not to swap X and W operands to the tl.dot
     w_want_k_major = False
-    w_want_n_major = False
     if not target_info.cuda_capability_geq(10, 0):
         # Hopper transpose. Reduction dimension must be contiguous.
         if w.stride(1) != 1 and w.dtype.itemsize == 1:
@@ -388,12 +383,7 @@ def init_preprocessing_features(w, precision_config, opt_flags):
     swap_xw = mx_ctx.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent
     if swap_xw:
         w_want_k_major = True
-    # fp4 padded mode requires the contiguous dim size to be a multiple of 64 bytes. If it is K-major and does not
-    # meet the requirement, make the tensor N-major instead.
-    # But, don't do this if we're going to swap X and W in which case we would transpose W again.
-    if w.stride(1) == 1 and w.dtype == torch.uint8 and w.shape[1] % 64 != 0 and not swap_xw:
-        w_want_n_major = True
-    return PreprocessingFeatures(w_want_n_major, w_want_k_major, swap_xw)
+    return PreprocessingFeatures(w_want_k_major, swap_xw)
 
 
 def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features):
@@ -420,10 +410,7 @@ def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data,
         finalize_scatter_idxs = None
     else:
         writeback_idxs, writeback_size, finalize_scatter_idxs = None, None, None
-    # some transposition variants aren't supported
-    if preprocessing_features.w_want_n_major:
-        w = fast_contiguous(w)
-    elif preprocessing_features.w_want_k_major:
+    if preprocessing_features.w_want_k_major:
         w = fast_contiguous(w.transpose(-1, -2)).transpose(-1, -2)
     # preprocess routing information and ptr lookup table
     M = x.shape[1] if gather_indx is None else gather_indx.src_indx.shape[0]
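
Note (not part of the diff): a minimal sketch of the K-major re-layout that the retained `w_want_k_major` path performs, assuming `fast_contiguous` behaves like `torch.Tensor.contiguous()` and `w` is shaped (..., K, N).

    import torch

    # w laid out (..., K, N) with N contiguous; the reduction (K) dim has stride != 1
    w = torch.empty(8, 48, 64)
    assert w.stride(-2) != 1

    # transpose, materialize contiguously, transpose back: same shape, K now contiguous
    w_k_major = w.transpose(-1, -2).contiguous().transpose(-1, -2)
    assert w_k_major.shape == w.shape and w_k_major.stride(-2) == 1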