Commit 73c6e33

Extend blackwell_matmul_descriptor_persistent to support transposed input
Differential Revision: D78756024
Pull Request resolved: #300
1 parent 863a713

1 file changed (+53 −31 lines)

tritonbench/operators/gemm/warp_spec_persistent_matmul.py

@@ -121,6 +121,7 @@ def matmul_kernel_tma(
     GROUP_SIZE_M: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,  #
     DTYPE: tl.constexpr,
+    IS_TRANSPOSE: tl.constexpr,
 ):
     dtype = DTYPE
@@ -145,7 +146,11 @@ def matmul_kernel_tma(
         offs_k = k * BLOCK_SIZE_K
         a = a_desc.load([offs_am, offs_k])
         b = b_desc.load([offs_bn, offs_k])
-        accumulator = tl.dot(a, b.T, accumulator)
+        if IS_TRANSPOSE:
+            arg2 = b
+        else:
+            arg2 = b.T
+        accumulator = tl.dot(a, arg2, accumulator)

     c = accumulator.to(dtype)

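Aside, not part of the diff: tl.dot contracts a [BLOCK_SIZE_M, BLOCK_SIZE_K] tile against a [BLOCK_SIZE_K, BLOCK_SIZE_N] tile, so whether the loaded B tile needs a .T depends only on which way B's descriptor is shaped. A minimal PyTorch shape sketch of that invariant (block sizes are illustrative):

    import torch

    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
    a_tile = torch.randn(BLOCK_M, BLOCK_K)

    # An [N, K]-shaped descriptor yields [BLOCK_N, BLOCK_K] tiles -> transpose.
    b_tile_nk = torch.randn(BLOCK_N, BLOCK_K)
    acc = a_tile @ b_tile_nk.T  # [BLOCK_M, BLOCK_N]

    # A [K, N]-shaped descriptor yields [BLOCK_K, BLOCK_N] tiles -> use as-is.
    b_tile_kn = torch.randn(BLOCK_K, BLOCK_N)
    acc = a_tile @ b_tile_kn  # [BLOCK_M, BLOCK_N]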
@@ -166,13 +171,7 @@ def warn_once(msg: str):


 def blackwell_matmul_tma(a, b, warp_specialize: bool):
-    # Check constraints.
-    if a.shape[1] != b.shape[1]:
-        warn_once(
-            "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
-        )
-        b = b.T.contiguous()
-    assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
+    is_transpose = a.shape[1] != b.shape[1]
     assert a.dtype == b.dtype, "Incompatible dtypes"

     M, K = a.shape
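Aside, not part of the diff: the new is_transpose flag is inferred purely from shapes. With a always (M, K), a B arriving as (K, N) trips the test; note the inference is ambiguous when N == K, since both layouts then have the same shape. A minimal sketch:

    import torch

    M, N, K = 512, 256, 128
    a = torch.randn(M, K)

    b_nk = torch.randn(N, K)  # B stored (N, K)
    assert not (a.shape[1] != b_nk.shape[1])  # is_transpose would be False

    b_kn = torch.randn(K, N)  # B stored (K, N)
    assert a.shape[1] != b_kn.shape[1]        # is_transpose would be True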
@@ -201,6 +200,7 @@ def grid(META):
         K,  #
         WARP_SPECIALIZE=warp_specialize,  #
         DTYPE=torch_dtype_to_triton_dtype(dtype),  #
+        IS_TRANSPOSE=is_transpose,
     )
     return c

@@ -258,6 +258,7 @@ def matmul_kernel_tma_persistent(
     NUM_SMS: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,  #
     DTYPE: tl.constexpr,
+    IS_TRANSPOSE: tl.constexpr,
 ):
     dtype = DTYPE
     start_pid = tl.program_id(axis=0)
@@ -286,7 +287,11 @@ def matmul_kernel_tma_persistent(
         offs_k = ki * BLOCK_SIZE_K
         a = a_desc.load([offs_am, offs_k])
         b = b_desc.load([offs_bn, offs_k])
-        accumulator = tl.dot(a, b.T, accumulator)
+        if IS_TRANSPOSE:
+            arg2 = b
+        else:
+            arg2 = b.T
+        accumulator = tl.dot(a, arg2, accumulator)

         tile_id_c += NUM_SMS
         pid_m, pid_n = _compute_pid(
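Aside, not part of the diff: because IS_TRANSPOSE is a tl.constexpr, the if/else above is resolved when the kernel is compiled; each specialization contains only one side of the branch and pays no runtime cost for the check. A self-contained sketch with a hypothetical kernel (requires a CUDA GPU):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(x_ptr, y_ptr, n, NEGATE: tl.constexpr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        if NEGATE:  # constexpr branch: the dead side is compiled away
            x = -x
        tl.store(y_ptr + offs, x, mask=mask)

    x = torch.randn(1024, device="cuda")
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 256),)
    # Each distinct NEGATE value triggers a separate specialization.
    scale_kernel[grid](x, y, x.numel(), NEGATE=True, BLOCK=256)
    assert torch.equal(y, -x)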
@@ -313,13 +318,7 @@ def matmul_kernel_tma_persistent(


 def blackwell_matmul_tma_persistent(a, b, warp_specialize: bool):
-    # Check constraints.
-    if a.shape[1] != b.shape[1]:
-        warn_once(
-            "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
-        )
-        b = b.T.contiguous()
-    assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
+    is_transpose = a.shape[1] != b.shape[1]
     assert a.dtype == b.dtype, "Incompatible dtypes"

     check_tma_alignment(a.stride(), (torch.finfo(a.dtype).bits + 7) // 8)
@@ -360,6 +359,7 @@ def grid(META):
         NUM_SMS=NUM_SMS,  #
         WARP_SPECIALIZE=warp_specialize,  #
         DTYPE=torch_dtype_to_triton_dtype(dtype),  #
+        IS_TRANSPOSE=is_transpose,
     )
     return c

@@ -395,6 +395,7 @@ def matmul_kernel_descriptor_persistent(
     NUM_SMS: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,  #
     FLATTEN: tl.constexpr,
+    TRANSPOSE_B: tl.constexpr,
 ):
     # Matmul using TMA and device-side descriptor creation
     dtype = c_ptr.dtype.element_ty
@@ -410,12 +411,20 @@ def matmul_kernel_descriptor_persistent(
         strides=[K, 1],
         block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
     )
-    b_desc = tl.make_tensor_descriptor(
-        b_ptr,
-        shape=[N, K],
-        strides=[K, 1],
-        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
-    )
+    if TRANSPOSE_B:
+        b_desc = tl.make_tensor_descriptor(
+            b_ptr,
+            shape=[N, K],
+            strides=[K, 1],
+            block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
+        )
+    else:
+        b_desc = tl.make_tensor_descriptor(
+            b_ptr,
+            shape=[K, N],
+            strides=[N, 1],
+            block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N],
+        )
     c_desc = tl.make_tensor_descriptor(
         c_ptr,
         shape=[M, N],
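Aside, not part of the diff: the two b_desc variants reinterpret the same pointer with different shape/stride metadata, so the stride-1 dimension stays innermost for TMA; which view matches the data depends on B's actual layout. A torch.as_strided analogue of the two views:

    import torch

    N, K = 4, 6
    buf = torch.arange(N * K, dtype=torch.float32)  # the same underlying storage

    # TRANSPOSE_B branch: B addressed as (N, K) with strides (K, 1), i.e. K is
    # the stride-1 dim; loads produce [BLOCK_SIZE_N, BLOCK_SIZE_K] tiles.
    b_nk = torch.as_strided(buf, size=(N, K), stride=(K, 1))

    # else branch: the same bytes addressed as (K, N) with strides (N, 1), i.e.
    # N is the stride-1 dim; loads produce [BLOCK_SIZE_K, BLOCK_SIZE_N] tiles.
    b_kn = torch.as_strided(buf, size=(K, N), stride=(N, 1))

    # Both views cover all N*K elements; only the logical indexing differs.
    assert b_nk.reshape(-1).equal(buf) and b_kn.reshape(-1).equal(buf)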
@@ -445,7 +454,11 @@ def matmul_kernel_descriptor_persistent(
         offs_k = ki * BLOCK_SIZE_K
         a = a_desc.load([offs_am, offs_k])
         b = b_desc.load([offs_bn, offs_k])
-        accumulator = tl.dot(a, b.T, accumulator)
+        if TRANSPOSE_B:
+            arg2 = b.T
+        else:
+            arg2 = b
+        accumulator = tl.dot(a, arg2, accumulator)

         tile_id_c += NUM_SMS
         pid_m, pid_n = _compute_pid(
@@ -468,17 +481,25 @@ def matmul_kernel_descriptor_persistent(


 def blackwell_matmul_descriptor_persistent(a, b, warp_specialize: bool):
-    # Check constraints.
-    if a.shape[1] != b.shape[1]:
-        warn_once(
-            "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
-        )
-        b = b.T.contiguous()
-    assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
+    # High-level options for B's layout:
+    # 1. (K, N) contiguous in N
+    # 2. (K, N) contiguous in K
+    # 3. (N, K) contiguous in N
+    # 4. (N, K) contiguous in K
+    # In practice, since we always load along the contiguous dimension,
+    # there are really only two options:
+    #   a. load with K as the stride-1 dim (cases 2 and 4)
+    #   b. load with N as the stride-1 dim (cases 1 and 3)
+    transpose_b = (a.shape[1] != b.shape[1] and b.stride()[-1] != 1) or (
+        a.shape[1] == b.shape[1] and b.stride()[-1] == 1
+    )
     assert a.dtype == b.dtype, "Incompatible dtypes"

     M, K = a.shape
-    N, K = b.shape
+    if a.shape[1] != b.shape[1]:
+        K, N = b.shape
+    else:
+        N, K = b.shape
     dtype = a.dtype

     c = torch.empty((M, N), device=a.device, dtype=dtype)
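Aside, not part of the diff: enumerating the four layouts from the new comment shows the predicate collapses them to two cases, keyed on whether K or N is the stride-1 dimension. A sketch with a hypothetical helper that replicates the predicate (M, N, K chosen pairwise distinct so the shape test is unambiguous):

    import torch

    M, N, K = 32, 16, 8
    a = torch.randn(M, K)

    def transpose_b_for(b):  # hypothetical helper, same predicate as above
        return (a.shape[1] != b.shape[1] and b.stride()[-1] != 1) or (
            a.shape[1] == b.shape[1] and b.stride()[-1] == 1
        )

    cases = {
        "(K, N) contiguous in N": torch.randn(K, N),    # strides (N, 1)
        "(K, N) contiguous in K": torch.randn(N, K).T,  # strides (1, K)
        "(N, K) contiguous in N": torch.randn(K, N).T,  # strides (1, N)
        "(N, K) contiguous in K": torch.randn(N, K),    # strides (K, 1)
    }
    for name, b in cases.items():
        print(f"{name}: transpose_b={transpose_b_for(b)}")
    # K-contiguous layouts (cases 2 and 4) -> True;
    # N-contiguous layouts (cases 1 and 3) -> False.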
@@ -507,5 +528,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
         WARP_SPECIALIZE=warp_specialize,  #
         # Note: This assumes blackwell.
         FLATTEN=True,
+        TRANSPOSE_B=transpose_b,
     )
     return c
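Aside, a hypothetical call-site sketch (not from the PR): after this change, blackwell_matmul_descriptor_persistent accepts B in either layout without the old b.T.contiguous() copy. Assumes a Blackwell GPU and the surrounding module's imports:

    import torch

    M, N, K = 1024, 1024, 512
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)

    b_kn = torch.randn(K, N, device="cuda", dtype=torch.float16)  # (K, N) row-major
    c1 = blackwell_matmul_descriptor_persistent(a, b_kn, warp_specialize=True)

    b_nk = torch.randn(N, K, device="cuda", dtype=torch.float16)  # (N, K) row-major
    c2 = blackwell_matmul_descriptor_persistent(a, b_nk, warp_specialize=True)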
