@@ -176,8 +176,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
176176
177177 :param sf_vec_size: Scalefactor vector size (16 for NVF4, 32 for MXF4/MXF8).
178178 :type sf_vec_size: int
179- :param acc_dtype: Data type of the accumulator (e.g., cutlass.Float32).
180- :type acc_dtype: Type[cutlass.Numeric]
181179 :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N).
182180 Note: use_2cta_instrs is automatically inferred from mma_tiler_mn[0]
183181 (True when M=256, False when M=128).
@@ -217,7 +215,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
217215 >>> # (True when M=256, False when M=128)
218216 >>> gemm = BlockScaledContiguousGatherGroupedGemmKernel(
219217 ... sf_vec_size=16,
220- ... acc_dtype=cutlass.Float32,
221218 ... mma_tiler_mn=(256, 128), # use_2cta_instrs=True since M=256
222219 ... cluster_shape_mn=(2, 1),
223220 ... vectorized_f32=True,
@@ -243,7 +240,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
243240 def __init__ (
244241 self ,
245242 sf_vec_size : int ,
246- acc_dtype : Type [cutlass .Numeric ],
247243 mma_tiler_mn : Tuple [int , int ],
248244 cluster_shape_mn : Tuple [int , int ],
249245 vectorized_f32 : bool ,
@@ -274,8 +270,6 @@ def __init__(
274270
275271 :param sf_vec_size: Vector size for scale factors (16 for NVF4, 32 for MXF4/MXF8).
276272 :type sf_vec_size: int
277- :param acc_dtype: Data type of the accumulator.
278- :type acc_dtype: type[cutlass.Numeric]
279273 :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
280274 use_2cta_instrs is automatically set based on M (True if M=256, False if M=128).
281275 :type mma_tiler_mn: Tuple[int, int]
@@ -289,7 +283,7 @@ def __init__(
289283
290284 self .sf_vec_size = sf_vec_size
291285 self .topk = topk
292- self .acc_dtype : Type [ cutlass . Numeric ] = acc_dtype
286+ self .acc_dtype = cutlass . Float32
293287 self .use_2cta_instrs = mma_tiler_mn [0 ] == 256
294288 self .cluster_shape_mn = cluster_shape_mn
295289 # K dimension is deferred in _setup_attributes
@@ -2620,7 +2614,6 @@ def is_valid_dtypes_and_scale_factor_vec_size(
26202614 ab_dtype : Type [cutlass .Numeric ],
26212615 sf_dtype : Type [cutlass .Numeric ],
26222616 sf_vec_size : int ,
2623- acc_dtype : Type [cutlass .Numeric ],
26242617 c_dtype : Type [cutlass .Numeric ],
26252618 ) -> bool :
26262619 """
@@ -2632,8 +2625,6 @@ def is_valid_dtypes_and_scale_factor_vec_size(
26322625 :type sf_dtype: Type[cutlass.Numeric]
26332626 :param sf_vec_size: The vector size of the scale factor
26342627 :type sf_vec_size: int
2635- :param acc_dtype: The data type of the accumulator
2636- :type acc_dtype: Type[cutlass.Numeric]
26372628 :param c_dtype: The data type of the output tensor
26382629 :type c_dtype: Type[cutlass.Numeric]
26392630
@@ -2662,8 +2653,6 @@ def is_valid_dtypes_and_scale_factor_vec_size(
26622653 if ab_dtype in {cutlass .Float8E5M2 , cutlass .Float8E4M3FN } and sf_vec_size == 16 :
26632654 is_valid = False
26642655
2665- if acc_dtype not in {cutlass .Float32 }:
2666- is_valid = False
26672656 # Check valid c_dtype
26682657 if c_dtype not in {
26692658 cutlass .Float32 ,
@@ -2717,7 +2706,6 @@ def is_valid_mma_tiler_and_cluster_shape(
27172706 use_2cta_instrs : bool ,
27182707 mma_tiler_mn : Tuple [int , int ],
27192708 cluster_shape_mn : Tuple [int , int ],
2720- m_aligned : cutlass .Int64 ,
27212709 ) -> bool :
27222710 """
27232711 Check if the mma tiler and cluster shape are valid
@@ -2728,8 +2716,6 @@ def is_valid_mma_tiler_and_cluster_shape(
27282716 :type mma_tiler_mn: Tuple[int, int]
27292717 :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
27302718 :type cluster_shape_mn: Tuple[int, int]
2731- :param m_aligned: The alignment requirement for group M dimension (default: 128)
2732- :type m_aligned: cutlass.Int64
27332719
27342720 :return: True if the mma tiler and cluster shape are valid, False otherwise
27352721 :rtype: bool
@@ -2771,13 +2757,6 @@ def is_valid_mma_tiler_and_cluster_shape(
27712757 if cluster_tiler_m not in [64 , 128 , 256 ]:
27722758 is_valid = False
27732759
2774- # Check if m_aligned is a multiple of cluster_tiler_m
2775- # This ensures that each group's M dimension (which is a multiple of m_aligned)
2776- # won't be split across tiles, preventing a single tile from loading data
2777- # from multiple groups (which would access wrong B matrix data)
2778- if m_aligned % mma_tiler_mn [0 ] != 0 :
2779- is_valid = False
2780-
27812760 return is_valid
27822761
27832762 @staticmethod
@@ -2838,7 +2817,6 @@ def can_implement(
28382817 ab_dtype : Type [cutlass .Numeric ],
28392818 sf_dtype : Type [cutlass .Numeric ],
28402819 sf_vec_size : int ,
2841- acc_dtype : Type [cutlass .Numeric ],
28422820 c_dtype : Type [cutlass .Numeric ],
28432821 mma_tiler_mn : Tuple [int , int ],
28442822 cluster_shape_mn : Tuple [int , int ],
@@ -2849,7 +2827,6 @@ def can_implement(
28492827 a_major : str ,
28502828 b_major : str ,
28512829 c_major : str ,
2852- m_aligned : cutlass .Int64 ,
28532830 ) -> bool :
28542831 """
28552832 Check if the gemm can be implemented
@@ -2860,8 +2837,6 @@ def can_implement(
28602837 :type sf_dtype: Type[cutlass.Numeric]
28612838 :param sf_vec_size: The vector size of the scale factor
28622839 :type sf_vec_size: int
2863- :param acc_dtype: The data type of the accumulator
2864- :type acc_dtype: Type[cutlass.Numeric]
28652840 :param c_dtype: The data type of the output tensor
28662841 :type c_dtype: Type[cutlass.Numeric]
28672842 :param use_2cta_instrs: Whether to use 2 CTA groups
@@ -2884,16 +2859,14 @@ def can_implement(
28842859 :type b_major: str
28852860 :param c_major: The major axis of the C tensor
28862861 :type c_major: str
2887- :param m_aligned: The alignment requirement for group M dimension (default: 128)
2888- :type m_aligned: cutlass.Int64
28892862
28902863 :return: True if the gemm can be implemented, False otherwise
28912864 :rtype: bool
28922865 """
28932866 can_implement = True
28942867 # Skip unsupported types
28952868 if not BlockScaledContiguousGatherGroupedGemmKernel .is_valid_dtypes_and_scale_factor_vec_size (
2896- ab_dtype , sf_dtype , sf_vec_size , acc_dtype , c_dtype
2869+ ab_dtype , sf_dtype , sf_vec_size , c_dtype
28972870 ):
28982871 can_implement = False
28992872
@@ -2903,10 +2876,10 @@ def can_implement(
29032876 ):
29042877 can_implement = False
29052878
2906- use_2cta_instrs = mma_tiler_mn [0 ] == 256
29072879 # Skip invalid mma tile shape and cluster shape
2880+ use_2cta_instrs = mma_tiler_mn [0 ] == 256
29082881 if not BlockScaledContiguousGatherGroupedGemmKernel .is_valid_mma_tiler_and_cluster_shape (
2909- use_2cta_instrs , mma_tiler_mn , cluster_shape_mn , m_aligned
2882+ use_2cta_instrs , mma_tiler_mn , cluster_shape_mn
29102883 ):
29112884 can_implement = False
29122885 # Skip illegal problem shape for load/store alignment
0 commit comments