
Commit 76c7f03

Merge branch 'main' into tkuczynski/enable_test_small_batch_matmul
2 parents: 421072e + d6b921e

85 files changed: +2566 / -832 lines changed


.github/workflows/integration-tests-amd.yml

Lines changed: 3 additions & 1 deletion
@@ -122,7 +122,9 @@ jobs:
           pytest --capture=tee-sys -rfs -n 8 python/test/gluon/
 
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
-          pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py
+          pytest --capture=tee-sys -rfs -n 8 third_party/amd/python/test/ \
+            --ignore=third_party/amd/python/test/test_scalarize_packed_fops.py \
+            --ignore=third_party/amd/python/test/test_address_sanitizer.py
           TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py
           cd python/test/unit
           pytest --capture=tee-sys -rfs -n 12 \

.github/workflows/runner-preparation.yml

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ jobs:
       - name: Detect if build deps (e.g. LLVM hash) changed
         id: detect-change
         if: github.event_name == 'push'
-        uses: tj-actions/changed-files@v46
+        uses: tj-actions/changed-files@v47
         with:
           files: |
             cmake/*.txt

benchmarks/triton_kernels_benchmark/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -3,7 +3,6 @@
 from .benchmark_testing import (
     assert_close,
     do_bench,
-    do_prewarmup,
     filter_providers,
     perf_report,
     Benchmark,
@@ -20,7 +19,6 @@
 __all__ = [
     "assert_close",
     "do_bench",
-    "do_prewarmup",
     "filter_providers",
     "perf_report",
     "Benchmark",

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 24 additions & 22 deletions
@@ -27,7 +27,6 @@
 BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
 BENCHMARKING_CONFIG = {
     "verify": os.getenv("VERIFY", "1") == "1",
-    "do_prewarmup": os.getenv("PREWARMUP", "1") == "1",
 }
 
 
@@ -42,19 +41,6 @@ def synchronize():
         torch.xpu.synchronize()
 
 
-def do_prewarmup(fn, min_seconds=5):
-    """Looks like some functions require pre-warmup with minimum time to do the compilation.
-    It has to be done once."""
-    if not BENCHMARKING_CONFIG["do_prewarmup"]:
-        return
-
-    start = time.time()
-    while time.time() - start < min_seconds:
-        fn()
-        synchronize()
-    BENCHMARKING_CONFIG["do_prewarmup"] = False
-
-
 def _summarize_statistics(times, quantiles, return_mode):
     if quantiles is not None:
         ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
@@ -73,7 +59,7 @@ def _summarize_statistics(times, quantiles, return_mode):
 
 
 def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean",
-                          device="xpu"):
+                          device="xpu", time_warmup=False):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -113,16 +99,20 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan
     del cache
 
     # compute warmup and repeat times
-    warmup_time = n_warmup * estimate_ms
+    if time_warmup:
+        warmup_ms = n_warmup
+    else:
+        warmup_ms = n_warmup * estimate_ms
     rep_time = n_repeat * estimate_ms
 
-    times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all")
+    times = triton_do_bench(fn, warmup=warmup_ms, rep=rep_time, grad_to_none=grad_to_none, return_mode="all")
     times = torch.tensor(times, dtype=torch.float)
     return _summarize_statistics(times, quantiles, return_mode)
 
 
 def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None,
-                                       return_mode="mean", device="xpu", sync_submitting=True, benchmark_label=None):
+                                       return_mode="mean", device="xpu", sync_submitting=True, time_warmup=True,
+                                       benchmark_label=None, max_iters=1500):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -151,11 +141,23 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
     cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
 
     # Warm-up
-    for _ in range(n_warmup):
-        fn()
-        # To be consistent with the benchmark measurements
-        if sync_submitting:
+    if time_warmup:
+        # Stop either on max iteration number or max time
+        warmup_time_s = n_warmup / 1000
+        assert sync_submitting
+        start = time.perf_counter()
+        i = 0
+        while i < max_iters and time.perf_counter() - start < warmup_time_s:
+            fn()
             synchronize()
+            i += 1
+        print(f"Stopped warmup after {i} iterations")
+    else:
+        for _ in range(n_warmup):
+            fn()
+            # To be consistent with the benchmark measurements
+            if sync_submitting:
+                synchronize()
 
     # Benchmark
     with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
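
For reference, this change alters the meaning of n_warmup in do_bench_upstream_pytorch_profiler: with time_warmup=True (the new default) it is a time budget in milliseconds, capped at max_iters iterations, while time_warmup=False keeps the old fixed iteration count. A minimal standalone sketch of that loop, mirroring the diff above (the warmup helper and the dummy workload are illustrative, not part of the patch):

import time

def warmup(fn, synchronize, n_warmup, time_warmup=True, max_iters=1500):
    """Warm up fn either for a time budget (n_warmup in ms) or a fixed iteration count."""
    if time_warmup:
        # Stop on whichever comes first: the iteration cap or the time budget.
        warmup_time_s = n_warmup / 1000
        start = time.perf_counter()
        i = 0
        while i < max_iters and time.perf_counter() - start < warmup_time_s:
            fn()
            synchronize()  # keep submission behaviour consistent with the measured runs
            i += 1
        print(f"Stopped warmup after {i} iterations")
    else:
        for _ in range(n_warmup):
            fn()
            synchronize()

# Example: warm a dummy CPU workload for at most 50 ms (or 1500 iterations).
warmup(lambda: sum(range(10_000)), synchronize=lambda: None, n_warmup=50)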

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 18 additions & 4 deletions
@@ -575,6 +575,15 @@ def get_benchmark(
     # pylint: disable=too-many-branches
     def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
         modes = ['fwd', 'bwd']
+        # This warmup logic improves performance on BMG significantly
+        # For FWD mode in triton & cutlass: Some configs increase performance with warmup as a step function, but some slowly decrease with saturation
+        # Performance is best at 250-400ms range, but we want stable, not just best at ~600ms (triton/cutlass providers)
+        n_warmup_fwd = 600
+        # For BWD mode: Performance doesn't really improve much with warmup for triton, but xetla benefit from more warmup
+        n_warmup_bwd = 400  # Maximum across xetla=400, triton=10, onednn=10
+        n_warmup = n_warmup_fwd if MODE == 'fwd' else n_warmup_bwd
+        # We keep old warmup value, because new warmup makes perfomance on PVC slightly worse
+        n_warmup = 10
         if MODE not in modes:
             raise AssertionError(f'Unknown {MODE}, supported modes are {modes}')
         dtype = torch.float16
@@ -602,9 +611,10 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
         if provider == 'onednn':
             _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
                 torch_fn,
-                n_warmup=10,
+                n_warmup=n_warmup,
                 n_repeat=10,
                 quantiles=quantiles,
+                time_warmup=False,
             )
 
         elif provider == 'triton':
@@ -623,11 +633,13 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
                 rtol=0,
                 err_msg='triton to torch',
             )
+
             _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
                 triton_fn,
-                n_warmup=10,
+                n_warmup=n_warmup,
                 n_repeat=10,
                 quantiles=quantiles,
+                time_warmup=False,
             )
 
         elif provider == 'xetla':
@@ -660,9 +672,10 @@ def xetla_bwd_fn():
 
             _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
                 xetla_bwd_fn,
-                n_warmup=10,
+                n_warmup=n_warmup,
                 n_repeat=10,
                 quantiles=quantiles,
+                time_warmup=False,
             )
 
         else:
@@ -685,9 +698,10 @@ def cutlass_fwd_fn():
 
             _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
                 cutlass_fwd_fn,
-                n_warmup=10,
+                n_warmup=n_warmup,
                 n_repeat=10,
                 quantiles=quantiles,
+                time_warmup=False,
             )
 
         else:
benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 4 additions & 4 deletions
@@ -137,6 +137,8 @@ def causal_mask(_, __, q_idx, kv_idx):
     args={},
 ))
 def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
+    # Maximum across torch=200, triton=600
+    n_warmup = 600
     if MODE not in ('fwd', 'bwd'):
         raise ValueError(f"Invalid MODE: {MODE}. Expected 'fwd' or 'bwd'.")
     dtype = torch.float16
@@ -156,7 +158,7 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
             mean = float('nan')
             cv = float('nan')
         else:
-            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10,
+            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=n_warmup, n_repeat=10,
                                                                   quantiles=quantiles, device=DEVICE)
 
     elif provider == 'triton':
@@ -181,10 +183,8 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
         else:
             benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
 
-        # Needs more warmup on B580 for some reason
-        benchmark_suit.do_prewarmup(triton_fn)
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
-            triton_fn, n_warmup=200, n_repeat=10, quantiles=quantiles, device=DEVICE, grad_to_none=(q, k, v),
+            triton_fn, n_warmup=n_warmup, n_repeat=10, quantiles=quantiles, device=DEVICE, grad_to_none=(q, k, v),
             benchmark_label=None if MODE == 'fwd' else 'CompiledFunctionBackward')
 
     else:

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_custom_masks.py

Lines changed: 5 additions & 4 deletions
@@ -82,6 +82,8 @@ def alibi_functional(score, _, h, q_idx, kv_idx):
     args={},
 ))
 def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
+    # There is still performance variance for triton, probably caused by random choice of autotune config
+    n_warmup = 200
     assert MODE in ['fwd', 'bwd']
     assert MASK in ['NATTEN', 'Alibi']
     dtype = torch.float16
@@ -112,9 +114,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        # Needs more warmup on B580 for some reason
-        benchmark_suit.do_prewarmup(triton_fn)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=5, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
+                                                              quantiles=quantiles)
         # Values checking cannot be implemented for these case as :
         # "The operator 'aten::_scaled_dot_product_flash_attention_for_cpu' is not currently implemented for the XPU device"
 
@@ -124,7 +125,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
         xformers_o = xformers_fn()
         xformers_do = torch.randn_like(xformers_o)
         xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10,
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=n_warmup, n_repeat=10,
                                                               quantiles=quantiles)
 
     else:

benchmarks/triton_kernels_benchmark/fused_softmax.py

Lines changed: 8 additions & 5 deletions
@@ -128,13 +128,16 @@ def get_benchmark(providers_filter: Optional[list[str]] = None):
         args={"M": 4096},  # values for function arguments not in `x_names` and `y_name`
     ))
     def benchmark(M, N, provider):
+        # Maximum across torch-native=10, triton=800, torch-jit=10, xetla=100, onednn=800
+        # For onednn more warmup very slowly makes performance worse
+        n_warmup = 800
         x = torch.randn(M, N, device="xpu", dtype=torch.bfloat16)
         quantiles = [0.5, 0.0, 1.0]
         if provider == "torch-native":
             _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
                 lambda: torch.softmax(x, axis=-1),
                 quantiles=quantiles,
-                n_warmup=10,
+                n_warmup=n_warmup,
                 n_repeat=10,
             )
         if provider == "triton":
@@ -145,13 +148,13 @@ def benchmark(M, N, provider):
             _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
                 triton_fn,
                 quantiles=quantiles,
-                n_warmup=10,
+                n_warmup=n_warmup,
                 n_repeat=10,
             )
 
         elif provider == "torch-jit":
             _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(lambda: naive_softmax(x), quantiles=quantiles,
-                                                                   n_warmup=10, n_repeat=10)
+                                                                   n_warmup=n_warmup, n_repeat=10)
 
         elif provider == "xetla":
             name = f"softmax_shape_{M}_{N}"
@@ -160,7 +163,7 @@ def benchmark(M, N, provider):
             xetla_fn = lambda: func(x, out, 0)
             torch_fn = lambda: torch.softmax(x, axis=-1)
             # benchmark_suite.assert_close(xetla_fn, torch_fn, err_msg="xetla to torch")
-            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10,
+            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(xetla_fn, quantiles=quantiles, n_warmup=n_warmup,
                                                                    n_repeat=10)
 
         elif provider == "onednn":
@@ -170,7 +173,7 @@ def benchmark(M, N, provider):
             onednn_fn = lambda: func(M, N, x, out, 1)
             torch_fn = lambda: torch.softmax(x, axis=-1)
             benchmark_suite.assert_close(onednn_fn, torch_fn, err_msg="onednn to torch")
-            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(onednn_fn, quantiles=quantiles, n_warmup=10,
+            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(onednn_fn, quantiles=quantiles, n_warmup=n_warmup,
                                                                    n_repeat=10)
 
         else:

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 6 additions & 4 deletions
@@ -340,6 +340,8 @@ def get_benchmark(
         args={},
     ))
     def benchmark(B, M, N, K, provider):
+        # Maximum across onednn=600, triton=800, xetla=10, cutlass=600
+        n_warmup = 800
         a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=transpose_a, transpose_b=transpose_b)
 
         torch.manual_seed(0)
@@ -359,7 +361,7 @@ def benchmark(B, M, N, K, provider):
         if provider == 'onednn':
             _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(
                 lambda: torch.matmul(torch_a, torch_b),
-                n_warmup=10,
+                n_warmup=n_warmup,
                 n_repeat=10,
                 quantiles=quantiles,
             )
@@ -387,7 +389,7 @@ def benchmark(B, M, N, K, provider):
             benchmark_suite.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
             _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(
                 triton_fn,
-                n_warmup=10,
+                n_warmup=n_warmup,
                 n_repeat=10,
                 quantiles=quantiles,
             )
@@ -421,7 +423,7 @@ def xetla_func_with_acc_allocation():
             # benchmark_suite.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
             _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(
                 xetla_fn,
-                n_warmup=10,
+                n_warmup=n_warmup,
                 n_repeat=10,
                 quantiles=quantiles,
             )
@@ -452,7 +454,7 @@ def cutlass_invoker():
             benchmark_suite.assert_close(cutlass_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='cutlass to torch')
             _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(
                 cutlass_fn,
-                n_warmup=10,
+                n_warmup=n_warmup,
                 n_repeat=10,
                 quantiles=quantiles,
             )

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 5 additions & 2 deletions
@@ -315,6 +315,9 @@ def is_enough_memory(x_val):
     args={},
 ))
 def benchmark(B, M, N, K, dtype, provider):
+    # Maximum across onednn=600, triton=1000
+    # For onednn and triton: Some configs increase performance with warmup as a step function, but some slowly decrease with saturation. Performance is best at 150-200ms range, but we want stable, not just best
+    n_warmup = 1000
     res_dtype = torch.float32 if dtype.is_floating_point else torch.int32
     if dtype.is_floating_point:
         rand = lambda shape, dtype: torch.rand(shape, device='xpu', dtype=dtype)
@@ -332,7 +335,7 @@ def benchmark(B, M, N, K, dtype, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b) + d, n_warmup=10,
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b) + d, n_warmup=n_warmup,
                                                                  n_repeat=10, quantiles=quantiles)
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
@@ -353,7 +356,7 @@ def benchmark(B, M, N, K, dtype, provider):
                                       [1, 512, 8192, 32768], [4, 32768, 4096, 128]]:
             # torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime
             benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
-            _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+            _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
                                                                      quantiles=quantiles)
         else:
             raise NotImplementedError(f'Unsupported provider {provider}')
