@@ -121,6 +121,7 @@ def matmul_kernel_tma(
121
121
GROUP_SIZE_M : tl .constexpr , #
122
122
WARP_SPECIALIZE : tl .constexpr , #
123
123
DTYPE : tl .constexpr ,
124
+ IS_TRANSPOSE : tl .constexpr ,
124
125
):
125
126
dtype = DTYPE
126
127
@@ -145,7 +146,11 @@ def matmul_kernel_tma(
145
146
offs_k = k * BLOCK_SIZE_K
146
147
a = a_desc .load ([offs_am , offs_k ])
147
148
b = b_desc .load ([offs_bn , offs_k ])
148
- accumulator = tl .dot (a , b .T , accumulator )
149
+ if IS_TRANSPOSE :
150
+ arg2 = b
151
+ else :
152
+ arg2 = b .T
153
+ accumulator = tl .dot (a , arg2 , accumulator )
149
154
150
155
c = accumulator .to (dtype )
151
156
@@ -166,13 +171,7 @@ def warn_once(msg: str):
166
171
167
172
168
173
def blackwell_matmul_tma (a , b , warp_specialize : bool ):
169
- # Check constraints.
170
- if a .shape [1 ] != b .shape [1 ]:
171
- warn_once (
172
- "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
173
- )
174
- b = b .T .contiguous ()
175
- assert a .shape [1 ] == b .shape [1 ], "Incompatible dimensions" # b is transposed
174
+ is_transpose = a .shape [1 ] != b .shape [1 ]
176
175
assert a .dtype == b .dtype , "Incompatible dtypes"
177
176
178
177
M , K = a .shape
@@ -201,6 +200,7 @@ def grid(META):
201
200
K , #
202
201
WARP_SPECIALIZE = warp_specialize , #
203
202
DTYPE = torch_dtype_to_triton_dtype (dtype ), #
203
+ IS_TRANSPOSE = is_transpose ,
204
204
)
205
205
return c
206
206
@@ -258,6 +258,7 @@ def matmul_kernel_tma_persistent(
258
258
NUM_SMS : tl .constexpr , #
259
259
WARP_SPECIALIZE : tl .constexpr , #
260
260
DTYPE : tl .constexpr ,
261
+ IS_TRANSPOSE : tl .constexpr ,
261
262
):
262
263
dtype = DTYPE
263
264
start_pid = tl .program_id (axis = 0 )
@@ -286,7 +287,11 @@ def matmul_kernel_tma_persistent(
286
287
offs_k = ki * BLOCK_SIZE_K
287
288
a = a_desc .load ([offs_am , offs_k ])
288
289
b = b_desc .load ([offs_bn , offs_k ])
289
- accumulator = tl .dot (a , b .T , accumulator )
290
+ if IS_TRANSPOSE :
291
+ arg2 = b
292
+ else :
293
+ arg2 = b .T
294
+ accumulator = tl .dot (a , arg2 , accumulator )
290
295
291
296
tile_id_c += NUM_SMS
292
297
pid_m , pid_n = _compute_pid (
@@ -313,13 +318,7 @@ def matmul_kernel_tma_persistent(
313
318
314
319
315
320
def blackwell_matmul_tma_persistent (a , b , warp_specialize : bool ):
316
- # Check constraints.
317
- if a .shape [1 ] != b .shape [1 ]:
318
- warn_once (
319
- "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
320
- )
321
- b = b .T .contiguous ()
322
- assert a .shape [1 ] == b .shape [1 ], "Incompatible dimensions" # b is transposed
321
+ is_transpose = a .shape [1 ] != b .shape [1 ]
323
322
assert a .dtype == b .dtype , "Incompatible dtypes"
324
323
325
324
check_tma_alignment (a .stride (), (torch .finfo (a .dtype ).bits + 7 ) // 8 )
@@ -360,6 +359,7 @@ def grid(META):
360
359
NUM_SMS = NUM_SMS , #
361
360
WARP_SPECIALIZE = warp_specialize , #
362
361
DTYPE = torch_dtype_to_triton_dtype (dtype ), #
362
+ IS_TRANSPOSE = is_transpose ,
363
363
)
364
364
return c
365
365
@@ -395,6 +395,7 @@ def matmul_kernel_descriptor_persistent(
395
395
NUM_SMS : tl .constexpr , #
396
396
WARP_SPECIALIZE : tl .constexpr , #
397
397
FLATTEN : tl .constexpr ,
398
+ TRANSPOSE_B : tl .constexpr ,
398
399
):
399
400
# Matmul using TMA and device-side descriptor creation
400
401
dtype = c_ptr .dtype .element_ty
@@ -410,12 +411,20 @@ def matmul_kernel_descriptor_persistent(
410
411
strides = [K , 1 ],
411
412
block_shape = [BLOCK_SIZE_M , BLOCK_SIZE_K ],
412
413
)
413
- b_desc = tl .make_tensor_descriptor (
414
- b_ptr ,
415
- shape = [N , K ],
416
- strides = [K , 1 ],
417
- block_shape = [BLOCK_SIZE_N , BLOCK_SIZE_K ],
418
- )
414
+ if TRANSPOSE_B :
415
+ b_desc = tl .make_tensor_descriptor (
416
+ b_ptr ,
417
+ shape = [N , K ],
418
+ strides = [K , 1 ],
419
+ block_shape = [BLOCK_SIZE_N , BLOCK_SIZE_K ],
420
+ )
421
+ else :
422
+ b_desc = tl .make_tensor_descriptor (
423
+ b_ptr ,
424
+ shape = [K , N ],
425
+ strides = [N , 1 ],
426
+ block_shape = [BLOCK_SIZE_K , BLOCK_SIZE_N ],
427
+ )
419
428
c_desc = tl .make_tensor_descriptor (
420
429
c_ptr ,
421
430
shape = [M , N ],
@@ -445,7 +454,11 @@ def matmul_kernel_descriptor_persistent(
445
454
offs_k = ki * BLOCK_SIZE_K
446
455
a = a_desc .load ([offs_am , offs_k ])
447
456
b = b_desc .load ([offs_bn , offs_k ])
448
- accumulator = tl .dot (a , b .T , accumulator )
457
+ if TRANSPOSE_B :
458
+ arg2 = b .T
459
+ else :
460
+ arg2 = b
461
+ accumulator = tl .dot (a , arg2 , accumulator )
449
462
450
463
tile_id_c += NUM_SMS
451
464
pid_m , pid_n = _compute_pid (
@@ -468,17 +481,25 @@ def matmul_kernel_descriptor_persistent(
468
481
469
482
470
483
def blackwell_matmul_descriptor_persistent (a , b , warp_specialize : bool ):
471
- # Check constraints.
472
- if a .shape [1 ] != b .shape [1 ]:
473
- warn_once (
474
- "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
475
- )
476
- b = b .T .contiguous ()
477
- assert a .shape [1 ] == b .shape [1 ], "Incompatible dimensions" # b is transposed
484
+ # High-Level Options for B's layout
485
+ # 1. (K, N) contiguous in N
486
+ # 2. (K, N) contiguous in K
487
+ # 3. (N, K) contiguous in N
488
+ # 4. (N, K) contiguous in K
489
+ # In practice, since you always load in the contiguous dimension
490
+ # there are actually only 2 options
491
+ # 1. Load in the K stride 1 (2 and 4)
492
+ # 2. Load in the N stride 1 (1 and 3)
493
+ transpose_b = (a .shape [1 ] != b .shape [1 ] and b .stride ()[- 1 ] != 1 ) or (
494
+ a .shape [1 ] == b .shape [1 ] and b .stride ()[- 1 ] == 1
495
+ )
478
496
assert a .dtype == b .dtype , "Incompatible dtypes"
479
497
480
498
M , K = a .shape
481
- N , K = b .shape
499
+ if a .shape [1 ] != b .shape [1 ]:
500
+ K , N = b .shape
501
+ else :
502
+ N , K = b .shape
482
503
dtype = a .dtype
483
504
484
505
c = torch .empty ((M , N ), device = a .device , dtype = dtype )
@@ -507,5 +528,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
507
528
WARP_SPECIALIZE = warp_specialize , #
508
529
# Note: This assumes blackwell.
509
530
FLATTEN = True ,
531
+ TRANSPOSE_B = transpose_b ,
510
532
)
511
533
return c