@@ -80,12 +80,14 @@ def get_valid_tactics(
8080
8181 # full shamoo
8282 mma_tiler_mn_candidates = [
83- (256 , 128 ),
83+ (128 , 64 ),
84+ (256 , 64 ),
8485 (128 , 128 ),
86+ (256 , 128 ),
87+ (128 , 192 ),
88+ (256 , 192 ),
8589 (128 , 256 ),
8690 (256 , 256 ),
87- (256 , 64 ),
88- (128 , 64 ),
8991 ]
9092 cluster_shape_mn_candidates = [
9193 (1 , 1 ),
@@ -99,37 +101,38 @@ def get_valid_tactics(
99101 (4 , 4 ),
100102 ]
101103 swap_ab_candidates = [True , False ]
104+ use_prefetch_candidates = [True , False ]
102105
103106 valid_tactics = []
104- for swap_ab in swap_ab_candidates :
105- for mma_tiler_mn in mma_tiler_mn_candidates :
106- for cluster_shape_mn in cluster_shape_mn_candidates :
107- if swap_ab :
108- c_major = "m"
109- kernel_m = n
110- kernel_n = m
111- else :
112- c_major = "n"
113- kernel_m = m
114- kernel_n = n
115-
116- if self .__class__ .kernel_class .can_implement (
117- cutlass .Float4E2M1FN , # ab_dtype,
118- cutlass .Float8E4M3FN , # sf_dtype
119- sf_vec_size , # sf_vec_size,
120- cutlass .BFloat16 , # c_dtype,
121- mma_tiler_mn ,
122- cluster_shape_mn ,
123- kernel_m ,
124- kernel_n ,
125- real_k ,
126- batch_size ,
127- a_major ,
128- b_major ,
129- c_major ,
130- ):
131- valid_tactics .append (
132- (mma_tiler_mn , cluster_shape_mn , swap_ab ))
107+ for mma_tiler_mn , cluster_shape_mn , swap_ab , use_prefetch in itertools . product (
108+ mma_tiler_mn_candidates , cluster_shape_mn_candidates ,
109+ swap_ab_candidates , use_prefetch_candidates ) :
110+ if swap_ab :
111+ c_major = "m"
112+ kernel_m = n
113+ kernel_n = m
114+ else :
115+ c_major = "n"
116+ kernel_m = m
117+ kernel_n = n
118+
119+ if self .__class__ .kernel_class .can_implement (
120+ cutlass .Float4E2M1FN , # ab_dtype,
121+ cutlass .Float8E4M3FN , # sf_dtype
122+ sf_vec_size , # sf_vec_size,
123+ cutlass .BFloat16 , # c_dtype,
124+ mma_tiler_mn ,
125+ cluster_shape_mn ,
126+ kernel_m ,
127+ kernel_n ,
128+ real_k ,
129+ batch_size ,
130+ a_major ,
131+ b_major ,
132+ c_major ,
133+ ):
134+ valid_tactics .append (
135+ (mma_tiler_mn , cluster_shape_mn , swap_ab , use_prefetch ))
133136
134137 return valid_tactics
135138
@@ -158,21 +161,22 @@ def forward(
158161 inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
159162 inputs[4]: Alpha scaling factor. dtype: float32.
160163 inputs[5]: Output dtype, expected to be torch.bfloat16.
161- tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).
164+ tactic: Tiling and cluster strategy as a tuple (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch); any non-tuple value falls back to the default ((128, 128), (1, 1), False, False).
162165
163166 Returns:
164167 torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
165168 """
166169 sf_vec_size = 16
167170
168171 if isinstance (tactic , tuple ):
169- mma_tiler_mn , cluster_shape_mn , swap_ab = tactic
172+ mma_tiler_mn , cluster_shape_mn , swap_ab , use_prefetch = tactic
170173 else :
171174 # fallback to default tactic
172- mma_tiler_mn , cluster_shape_mn , swap_ab = [
175+ mma_tiler_mn , cluster_shape_mn , swap_ab , use_prefetch = [
173176 (128 , 128 ),
174177 (1 , 1 ),
175178 False ,
179+ False ,
176180 ]
177181
178182 a_tensor , b_tensor , a_sf_tensor , b_sf_tensor = inputs
@@ -208,7 +212,8 @@ def forward(
208212 torch_stream = torch .cuda .current_stream ()
209213 stream = cuda .CUstream (torch_stream .cuda_stream )
210214
211- cache_key = (sf_vec_size , mma_tiler_mn , cluster_shape_mn , swap_ab )
215+ cache_key = (sf_vec_size , mma_tiler_mn , cluster_shape_mn , swap_ab ,
216+ use_prefetch )
212217 if swap_ab :
213218 kernel_a_ptr = b_ptr
214219 kernel_a_sf_ptr = b_sf_ptr
@@ -233,6 +238,7 @@ def forward(
233238 sf_vec_size ,
234239 mma_tiler_mn ,
235240 cluster_shape_mn ,
241+ use_prefetch ,
236242 )
237243 # Compute max active clusters on current device
238244 hardware_info = cutlass .utils .HardwareInfo ()
@@ -257,6 +263,7 @@ def forward(
257263 max_active_clusters ,
258264 stream ,
259265 swap_ab ,
266+ options = f"--opt-level 2" ,
260267 )
261268
262269 self .__class__ .kernel_cache [cache_key ] = compiled_gemm
0 commit comments