@@ -65,8 +65,7 @@ def grouped_mm_configs():
     return _NV_CONFIGS


-def early_config_prune(g, m, configs, named_args):
-    dtsize = 1
+def early_config_prune(g, m, dtsize, configs, named_args):
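+    # dtsize: itemsize in bytes of the operand dtype, now supplied by the
+    # caller instead of being hardcoded to 1.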
     pruned_configs = []
     for config in configs:
         kw = config.kwargs
@@ -186,15 +185,29 @@ def early_config_prune(g, m, configs, named_args):
 {%- endif %}
         a_ptr,
 {%- if A_IS_2D %}
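+        # K-major A uses an (M, K) descriptor tiled (BLOCK_M, BLOCK_K);
+        # otherwise (K, M) tiled (BLOCK_K, BLOCK_M). The 3D branches below
+        # prepend the group dimension G.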
+{%- if A_IS_K_MAJOR %}
         shape=[M, K],
         # fixme: strides=[A_STRIDE_M, A_STRIDE_K],
         strides=[{{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}],
         block_shape=[BLOCK_M, BLOCK_K],
 {%- else %}
+        shape=[K, M],
+        # fixme: strides=[A_STRIDE_K, A_STRIDE_M],
+        strides=[{{stride("a_ptr", -1)}}, {{stride("a_ptr", -2)}}],
+        block_shape=[BLOCK_K, BLOCK_M],
+{%- endif %}
+{%- else %}
+{%- if A_IS_K_MAJOR %}
         shape=[G, M, K],
         # fixme: strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K],
         strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}],
         block_shape=[1, BLOCK_M, BLOCK_K],
+{%- else %}
+        shape=[G, K, M],
+        # fixme: strides=[A_STRIDE_G, A_STRIDE_K, A_STRIDE_M],
+        strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -1)}}, {{stride("a_ptr", -2)}}],
+        block_shape=[1, BLOCK_K, BLOCK_M],
+{%- endif %}
 {%- endif %}
     )
 
@@ -205,15 +218,29 @@ def early_config_prune(g, m, configs, named_args):
 {%- endif %}
         b_ptr,
 {%- if B_IS_2D %}
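+        # K-major B uses an (N, K) descriptor tiled (BLOCK_N, BLOCK_K);
+        # otherwise (K, N) tiled (BLOCK_K, BLOCK_N), mirroring the A setup.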
+{%- if B_IS_K_MAJOR %}
         shape=[N, K],
         # fixme: strides=[B_STRIDE_N, B_STRIDE_K],
         strides=[{{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}],
         block_shape=[BLOCK_N, BLOCK_K],
 {%- else %}
+        shape=[K, N],
+        # fixme: strides=[B_STRIDE_K, B_STRIDE_N],
+        strides=[{{stride("b_ptr", -2)}}, {{stride("b_ptr", -1)}}],
+        block_shape=[BLOCK_K, BLOCK_N],
+{%- endif %}
+{%- else %}
+{%- if B_IS_K_MAJOR %}
         shape=[G, N, K],
         # fixme: strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K],
         strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}],
         block_shape=[1, BLOCK_N, BLOCK_K],
+{%- else %}
+        shape=[G, K, N],
+        # fixme: strides=[B_STRIDE_G, B_STRIDE_K, B_STRIDE_N],
+        strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -2)}}, {{stride("b_ptr", -1)}}],
+        block_shape=[1, BLOCK_K, BLOCK_N],
+{%- endif %}
 {%- endif %}
     )
 {%- endif %}
@@ -286,39 +313,82 @@ def early_config_prune(g, m, configs, named_args):
         accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
 
 {%- if USE_TMA_LOAD %}
-        m_offset = (m_start_offset + tile_m_idx * BLOCK_M).to(tl.int32)
-        n_offset = (n_start_offset + tile_n_idx * BLOCK_N).to(tl.int32)
+        m_tile_offset = tile_m_idx * BLOCK_M
+        n_tile_offset = tile_n_idx * BLOCK_N
+        m_offset = (m_start_offset + m_tile_offset).to(tl.int32)
+        n_offset = (n_start_offset + n_tile_offset).to(tl.int32)
 
-        for k_offset in range(0, k_size, BLOCK_K):
+        for k_block_offset in range(0, k_size, BLOCK_K):
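+            # Coordinate order of each load matches the descriptor layout
+            # chosen above, so non-K-major tiles come back K-first.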
 {%- if A_IS_2D %}
-            a = a_desc.load([m_offset, k_start_offset + k_offset])
+{%- if A_IS_K_MAJOR %}
+            a = a_desc.load([m_offset, k_start_offset + k_block_offset])
+{%- else %}
+            a = a_desc.load([k_start_offset + k_block_offset, m_offset])
+{%- endif %}
+{%- else %}
+{%- if A_IS_K_MAJOR %}
+            a = a_desc.load([g, m_offset, k_start_offset + k_block_offset]).reshape(BLOCK_M, BLOCK_K)
 {%- else %}
-            a = a_desc.load([g, m_offset, k_start_offset + k_offset]).reshape(BLOCK_M, BLOCK_K)
+            a = a_desc.load([g, k_start_offset + k_block_offset, m_offset]).reshape(BLOCK_K, BLOCK_M)
+{%- endif %}
 {%- endif %}
 {%- if B_IS_2D %}
-            b = b_desc.load([n_offset, k_start_offset + k_offset])
+{%- if B_IS_K_MAJOR %}
+            b = b_desc.load([n_offset, k_start_offset + k_block_offset])
 {%- else %}
-            b = b_desc.load([g, n_offset, k_start_offset + k_offset]).reshape(BLOCK_N, BLOCK_K)
+            b = b_desc.load([k_start_offset + k_block_offset, n_offset])
+{%- endif %}
+{%- else %}
+{%- if B_IS_K_MAJOR %}
+            b = b_desc.load([g, n_offset, k_start_offset + k_block_offset]).reshape(BLOCK_N, BLOCK_K)
+{%- else %}
+            b = b_desc.load([g, k_start_offset + k_block_offset, n_offset]).reshape(BLOCK_K, BLOCK_N)
+{%- endif %}
 {%- endif %}
 
 {%- if K_IS_VARYING %}
-            if k_offset + BLOCK_K > k_size:
-                group_offs_k = k_offset + tl.arange(0, BLOCK_K)
-                a = tl.where(group_offs_k < k_size, a, 0)
-                b = tl.where(group_offs_k < k_size, b, 0)
+            if k_block_offset + BLOCK_K > k_size:
+                group_offs = k_block_offset + tl.arange(0, BLOCK_K)
+                k_mask = group_offs < k_size
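+                # Broadcast the K mask along whichever axis holds K in
+                # each tile: last axis for K-major, first axis otherwise.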
+{%- if A_IS_K_MAJOR %}
+                a = tl.where(k_mask[None, :], a, 0)
+{%- else %}
+                a = tl.where(k_mask[:, None], a, 0)
+{%- endif %}
+{%- if B_IS_K_MAJOR %}
+                b = tl.where(k_mask[None, :], b, 0)
+{%- else %}
+                b = tl.where(k_mask[:, None], b, 0)
+{%- endif %}
 {%- endif %}
 
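+            # tl.dot needs (BLOCK_M, BLOCK_K) @ (BLOCK_K, BLOCK_N), so the
+            # K-major B tile and any non-K-major A tile are transposed first.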
 {%- if USE_FAST_ACCUM %}
+{%- if A_IS_K_MAJOR and B_IS_K_MAJOR %}
             accumulator = tl.dot(a, b.T, accumulator)
+{%- elif A_IS_K_MAJOR and not B_IS_K_MAJOR %}
+            accumulator = tl.dot(a, b, accumulator)
+{%- elif not A_IS_K_MAJOR and B_IS_K_MAJOR %}
+            accumulator = tl.dot(a.T, b.T, accumulator)
 {%- else %}
+            accumulator = tl.dot(a.T, b, accumulator)
+{%- endif %}
+{%- else %}
+{%- if A_IS_K_MAJOR and B_IS_K_MAJOR %}
             accumulator += tl.dot(a, b.T)
+{%- elif A_IS_K_MAJOR and not B_IS_K_MAJOR %}
+            accumulator += tl.dot(a, b)
+{%- elif not A_IS_K_MAJOR and B_IS_K_MAJOR %}
+            accumulator += tl.dot(a.T, b.T)
+{%- else %}
+            accumulator += tl.dot(a.T, b)
+{%- endif %}
 {%- endif %}
 {%- else %}
         offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
         offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
-        for k_offset in range(0, k_size, BLOCK_K):
-            group_offs_k = k_offset + tl.arange(0, BLOCK_K)
-            offs_k = group_offs_k + k_start_offset
+        for k_block_offset in range(0, k_size, BLOCK_K):
+            block_offs_k = k_block_offset + tl.arange(0, BLOCK_K)
+            offs_k = block_offs_k + k_start_offset
             a_ptrs = (
                 a_ptr
 {%- if not A_IS_2D %}
@@ -335,10 +405,10 @@ def early_config_prune(g, m, configs, named_args):
                 + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N
                 + offs_k[None, :] * B_STRIDE_K
             )
-            a_mask = (offs_am[:, None] < m_size) & (group_offs_k[None, :] < k_size)
-            b_mask = (offs_bn[:, None] < n_size) & (group_offs_k[None, :] < k_size)
-            a = tl.load(a_ptrs, mask=a_mask, other=0)
-            b = tl.load(b_ptrs, mask=b_mask, other=0)
+            a_mask = (offs_am[:, None] < m_size) & (block_offs_k[None, :] < k_size)
+            b_mask = (offs_bn[:, None] < n_size) & (block_offs_k[None, :] < k_size)
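+            # Fill masked-out elements with a zero of the pointer's element
+            # type rather than a bare 0, keeping the load dtype consistent.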
+            a = tl.load(a_ptrs, mask=a_mask, other=tl.zeros((), dtype=a_ptrs.dtype.element_ty))
+            b = tl.load(b_ptrs, mask=b_mask, other=tl.zeros((), dtype=b_ptrs.dtype.element_ty))
 {%- if USE_FAST_ACCUM %}
             accumulator = tl.dot(a, b.T, accumulator)
 {%- else %}
@@ -360,7 +430,7 @@ def early_config_prune(g, m, configs, named_args):
 {%- endif %}
             + offs_am[:, None],
             mask=offs_am[:, None] < m_size,
-            other=0,
+            other=tl.zeros((), dtype=scale_a_ptr.dtype.element_ty),
         )
         scale_b = tl.load(
             scale_b_ptr
@@ -371,7 +441,7 @@ def early_config_prune(g, m, configs, named_args):
 {%- endif %}
             + offs_bn[None, :],
             mask=offs_bn[None, :] < n_size,
-            other=0,
+            other=tl.zeros((), dtype=scale_b_ptr.dtype.element_ty),
         )
         c = accumulator.to(tl.float32) * scale_a * scale_b
 {%- else %}
@@ -648,6 +718,9 @@ def _tuned_grouped_mm_common(
         V.graph.sizevars.check_equals(k1, k2)
         a_is_2d, b_is_2d = False, False
 
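+    # K-major means K is the contiguous (stride-1) dimension: the last dim
+    # of mat_a and the second-to-last dim of mat_b.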
+    a_is_k_major = mat_a.get_stride()[-1] == 1
+    b_is_k_major = mat_b.get_stride()[-2] == 1
+
     triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor")
     triton_has_experimental_make_tensor_descriptor = hasattr(
         tl, "_experimental_make_tensor_descriptor"
@@ -656,22 +729,21 @@ def _tuned_grouped_mm_common(
         triton_has_make_tensor_descriptor
         or triton_has_experimental_make_tensor_descriptor
     )
-    # The make_tensor_descriptor imposes this additional limitation.
-    use_tma_load = use_tma_load and (
-        mat_a.get_stride()[-1] == 1 and mat_b.get_stride()[-2] == 1
-    )
-
     kwargs = {
         "SCALED": scaled,
         "A_IS_2D": a_is_2d,
         "B_IS_2D": b_is_2d,
+        "A_IS_K_MAJOR": a_is_k_major,
+        "B_IS_K_MAJOR": b_is_k_major,
668738 "USE_FAST_ACCUM" : use_fast_accum ,
669739 "NUM_SMS" : get_num_sms (),
670740 "USE_TMA_LOAD" : use_tma_load ,
671741 "USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR" : triton_has_experimental_make_tensor_descriptor ,
672742 }
673743
674- for config in early_config_prune (g , m , grouped_mm_configs (), kwargs ):
744+ for config in early_config_prune (
745+ g , m , mat_a .dtype .itemsize , grouped_mm_configs (), kwargs
746+ ):
675747 kernel_template .maybe_append_choice (
676748 choices ,
677749 input_nodes = input_nodes ,