Commit db2155c

Add relative and absolute tolerance as command-line arguments
Differential Revision: D80137287
Pull Request resolved: #329
1 parent 5d3279c commit db2155c
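
This commit threads `--rtol` and `--atol` from the command line into the `accuracy` metric's `torch.testing.assert_close` calls. A hypothetical invocation (the entry-point script and the `--op`/`--metrics` flags are assumed from tritonbench's usual CLI and are not part of this diff):

    python run.py --op gemm --metrics accuracy --rtol 1e-3 --atol 1e-3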

4 files changed: +65 −15 lines

tritonbench/operators/gemm/operator.py

Lines changed: 11 additions & 3 deletions
@@ -313,7 +313,9 @@ def pt2_matmul_maxautotune(self, a, b, bias) -> Callable:
 
     @register_benchmark(enabled=not is_cuda())
     def streamk_matmul(self, a, b, bias) -> Callable:
-        return lambda: streamk_amd_matmul(a, b, bias) if bias else streamk_amd_matmul(a, b)
+        return (
+            lambda: streamk_amd_matmul(a, b, bias) if bias else streamk_amd_matmul(a, b)
+        )
 
     @register_benchmark(enabled=is_cuda())
     def streamk_matmul(self, a, b, bias) -> Callable:
@@ -322,8 +324,14 @@ def streamk_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         baseline = streamk_cuda_matmul(a, b)
         if not torch.allclose(streamk, baseline):
-            print(f"StreamK matmul on {a.shape} x {b.shape} result does not match baseline matmul result. Max abs(streamk/baseline - 1): {torch.max(torch.abs(streamk / baseline - 1))}")
-        return lambda: streamk_cuda_matmul(a, b) + bias if bias else streamk_cuda_matmul(a, b)
+            print(
+                f"StreamK matmul on {a.shape} x {b.shape} result does not match baseline matmul result. Max abs(streamk/baseline - 1): {torch.max(torch.abs(streamk / baseline - 1))}"
+            )
+        return (
+            lambda: streamk_cuda_matmul(a, b) + bias
+            if bias
+            else streamk_cuda_matmul(a, b)
+        )
 
     @register_benchmark(enabled=is_cuda())
     def pt2_cutlass_matmul(self, a, b, bias) -> Callable:
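
Note that the mismatch message above reports the largest element-wise relative deviation rather than an absolute difference. A minimal standalone sketch of the same check (tensor values are illustrative):

import torch

streamk = torch.tensor([1.0, 2.001, 3.0])
baseline = torch.tensor([1.0, 2.0, 3.0])
if not torch.allclose(streamk, baseline):
    # largest element-wise relative deviation, as the message above reports
    print(torch.max(torch.abs(streamk / baseline - 1)))  # ≈ 5e-04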

tritonbench/operators/gemm/stream_k.py

Lines changed: 18 additions & 8 deletions
@@ -390,6 +390,7 @@ def streamk_amd_matmul(a, b, bias=None):
     # print(a @ b)
     return c
 
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
@@ -406,19 +407,26 @@ def _matmul_launch_metadata(grid, kernel, args):
 def matmul_get_configs(pre_hook=None):
     return [
         triton.Config(
-            {"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK, "SK_BLOCK_K": skBK, "GROUP_M": 8},
+            {
+                "BLOCK_M": BM,
+                "BLOCK_N": BN,
+                "BLOCK_K": BK,
+                "SK_BLOCK_K": skBK,
+                "GROUP_M": 8,
+            },
             num_stages=s,
             num_warps=w,
             pre_hook=pre_hook,
         )  #
         for BM in [128, 256]  #
         for BN in [128, 256]  #
         for BK in [32, 64, 128]  #
-        for skBK in [16, 32, 64, 128] #
+        for skBK in [16, 32, 64, 128]  #
         for s in ([2, 3, 4])  #
         for w in [4, 8]  #
     ]
 
+
 def matmul_tma_set_block_size_hook(nargs):
     BLOCK_M = nargs["BLOCK_M"]
     BLOCK_N = nargs["BLOCK_N"]
@@ -431,6 +439,7 @@ def matmul_tma_set_block_size_hook(nargs):
     nargs["a_desc_sk"].block_shape = [BLOCK_M, SK_BLOCK_K]
     nargs["b_desc_sk"].block_shape = [BLOCK_N, SK_BLOCK_K]
 
+
 @triton.autotune(
     configs=matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook),
     key=["M", "N", "K"],
@@ -494,7 +503,6 @@ def streamk_cuda_gemm(
     total_ddp_tiles = num_pid - NUM_SMS
     streamk_sms = NUM_SMS
 
-
     # ----------------------------------------------------------------------------
     # DDP phase
     # ----------------------------------------------------------------------------
@@ -534,12 +542,12 @@
 
     # `evenly` distribute work units across SMs, with rem tiles assigned contiguously to the first rem programs
     base = total_work_units // streamk_sms
-    rem = total_work_units % streamk_sms
+    rem = total_work_units % streamk_sms
     work = tl.where(worker_id < rem, base + 1, base)
     start = tl.where(
         worker_id < rem,
         worker_id * (base + 1),
-        rem * (base + 1) + (worker_id - rem) * base
+        rem * (base + 1) + (worker_id - rem) * base,
    )
    end = start + work - 1
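The `base`/`rem` arithmetic in this hunk is a standard balanced partition: every worker gets `base` work units, and the first `rem` workers get one extra, assigned contiguously. A plain-Python sketch of the same index math (values are illustrative; the kernel computes this per-program with `tl.where`):

total_work_units, streamk_sms = 10, 4
base, rem = divmod(total_work_units, streamk_sms)
for worker_id in range(streamk_sms):
    work = base + 1 if worker_id < rem else base
    start = (
        worker_id * (base + 1)
        if worker_id < rem
        else rem * (base + 1) + (worker_id - rem) * base
    )
    end = start + work - 1
    print(worker_id, start, end)  # workers 0..3 get units 0-2, 3-5, 6-7, 8-9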

@@ -567,7 +575,9 @@
 
     # compute the start and end K index on this tile for this work unit
     curr_st_k = tl.where(curr_tile == st_tile_streamk, st_k_streamk, 0)
-    curr_en_k = tl.where(curr_tile == en_tile_streamk, en_k_streamk, work_units_per_tile - 1)
+    curr_en_k = tl.where(
+        curr_tile == en_tile_streamk, en_k_streamk, work_units_per_tile - 1
+    )
 
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
 
@@ -590,6 +600,7 @@
     # NOTE: known correctness issue with atomic_add
     c_desc.atomic_add([offs_am, offs_bn], c)
 
+
 def streamk_cuda_matmul(a, b):
     assert a.dtype == b.dtype, "Incompatible dtypes"
 
@@ -624,7 +635,6 @@ def grid(META):
         streamk_sms = num_sms
         return (total_ddp_tiles + streamk_sms,)
 
-
     streamk_cuda_gemm[grid](
         a_desc,
         b_desc,
@@ -636,6 +646,6 @@ def grid(META):
         K,  #
         FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
         ENABLE_BUFFER_OPS_ASSUMES=True,  #
-        NUM_SMS=num_sms #
+        NUM_SMS=num_sms  #
     )
     return c
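
`NUM_SMS` sizes both the data-parallel grid and the Stream-K worker pool; `num_sms` is computed earlier in `streamk_cuda_matmul`, outside this diff. A sketch of the usual way such a value is queried (assumes a CUDA device is present; the actual code in stream_k.py may differ):

import torch

# number of streaming multiprocessors (SMs) on the current CUDA device
num_sms = torch.cuda.get_device_properties(0).multi_processor_count
print(num_sms)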

tritonbench/utils/parser.py

Lines changed: 12 additions & 0 deletions
@@ -246,6 +246,18 @@ def get_parser(args=None):
         default=None,
         help="Name of group for benchmarking.",
     )
+    parser.add_argument(
+        "--rtol",
+        type=float,
+        default=None,
+        help="Relative tolerance for accuracy metric.",
+    )
+    parser.add_argument(
+        "--atol",
+        type=float,
+        default=None,
+        help="Absolute tolerance for accuracy metric.",
+    )
 
     # A/B Testing parameters
     parser.add_argument(
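
Both flags default to `None`, so existing runs are unaffected unless they are passed explicitly. A minimal sketch of how they surface (assumes `get_parser` returns an argparse parser and that no other flags are required on this path):

from tritonbench.utils.parser import get_parser

args = get_parser().parse_args(["--rtol", "1e-3", "--atol", "1e-3"])
assert args.rtol == 1e-3 and args.atol == 1e-3  # both default to None when omitted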

tritonbench/utils/triton_op.py

Lines changed: 24 additions & 4 deletions
@@ -1145,14 +1145,34 @@ def accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         baseline_output = baseline_fn()
         try:
             if self.mode == Mode.FWD:
-                torch.testing.assert_close(output, baseline_output)
+                torch.testing.assert_close(
+                    output,
+                    baseline_output,
+                    rtol=self.tb_args.rtol,
+                    atol=self.tb_args.atol,
+                )
             elif self.mode == Mode.BWD:
-                torch.testing.assert_close(output.grad, baseline_output.grad)
+                torch.testing.assert_close(
+                    output.grad,
+                    baseline_output.grad,
+                    rtol=self.tb_args.rtol,
+                    atol=self.tb_args.atol,
+                )
             else:
                 fwd_output, loss = output
                 baseline_fwd_output, baseline_loss = baseline_output
-                torch.testing.assert_close(fwd_output, baseline_fwd_output)
-                torch.testing.assert_close(loss.grad, baseline_loss.grad)
+                torch.testing.assert_close(
+                    fwd_output,
+                    baseline_fwd_output,
+                    rtol=self.tb_args.rtol,
+                    atol=self.tb_args.atol,
+                )
+                torch.testing.assert_close(
+                    loss.grad,
+                    baseline_loss.grad,
+                    rtol=self.tb_args.rtol,
+                    atol=self.tb_args.atol,
+                )
             return True
         except Exception:
             # either the output tensor or the loss grad tensor does not match
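
A note on semantics: `torch.testing.assert_close` treats `rtol=None, atol=None` like omitting both and falls back to dtype-based defaults, and it raises if only one of the two is specified, so the new flags should be passed together. A small illustration:

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0001, 2.0001])

try:
    # rtol=None, atol=None behaves like omitting both: dtype-based defaults
    # (float32: rtol=1.3e-6, atol=1e-5), too tight for a 1e-4 difference
    torch.testing.assert_close(a, b, rtol=None, atol=None)
except AssertionError:
    pass

# explicit looser tolerances, as --rtol/--atol now allow, make the check pass
torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-3)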
