
Commit fb2fff8

[CI][tutorials] Split tutorial 06-fused-attention.py run in build process to accelerate runtime (#4993)
Closes #4948

Implementation details:

1. Previously the `mxfp` suite in the workflow was responsible for running tests and for running tutorial 6. I moved those tests to the `rest` suite; the mxfp test takes ~1 min, so the move does not affect runtime significantly. I split tutorial 6 into 3 separate parts (total compute time ~45 min -> ~25 min for the slowest part after the split).
2. I added parsing of environment variables to tutorial 6 to specify which config to run.
3. I modified `test-triton.sh` and `pytest-utils.sh` to support tutorial selection based on input arguments. For that I moved the select-from-file logic from `pytest-utils.sh` to `test-triton.sh`. I think that is reasonable, because that way it is clear how some benchmarks get skipped.
4. Further investigation of tutorial 6 performance shows that the majority of the time is spent on forward-kernel autotuning, which also happens during the backward phase because of a single forward call used for shape inference. That autotuning is dtype specific (separate for FP8 & FP16), and FP8 is much slower. Moreover, the backward mode for FP8 reuses the FP16 forward and backward pass. So, to optimize autotune runs, there are 3 configs: head dim 64; head dim 128 with autotune for FP8 (forward FP8 only); and head dim 128 with autotune for FP16 (forward FP16 plus backward FP8 & FP16).

An alternative implementation would be to put tutorial 6 in the file with the tutorial list, set the environment variables, and then run the tutorials. It would require fewer changes, but it would make the new split functionality non-obvious and hard to use when only a specific subconfig should run. With this implementation everything stays explicit and obvious.
1 parent b137b65 commit fb2fff8
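
To exercise the three parts locally, each new test-triton.sh flag selects one of the configs described above. A minimal sketch, assuming the script is invoked from the repository root with a working Triton install (the flags and their env-var mappings come from the scripts/test-triton.sh diff below):

    # Part 1: head dim 64 (exports HEAD_DIM=64)
    ./scripts/test-triton.sh --tutorial-fa-64
    # Part 2: head dim 128, FP8 forward only (exports HEAD_DIM=128 FWD_FP8_ONLY=1)
    ./scripts/test-triton.sh --tutorial-fa-128-fwdfp8
    # Part 3: head dim 128, FP16 forward plus FP8/FP16 backward (exports HEAD_DIM=128 FWD_FP8_SKIP=1)
    ./scripts/test-triton.sh --tutorial-fa-128-nofwdfp8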

File tree

4 files changed: +120 -31 lines changed


.github/workflows/build-test-reusable.yml

Lines changed: 12 additions & 12 deletions
@@ -195,9 +195,11 @@ jobs:
       matrix:
         suite:
           - minicore
-          - mxfp
           - scaled_dot
           - rest
+          - tutorial-fa-64
+          - tutorial-fa-128-fwdfp8
+          - tutorial-fa-128-nofwdfp8
     timeout-minutes: 720
     runs-on: ${{ fromJson(inputs.runner_label && format('["linux", "{0}"]', inputs.runner_label) || format('["linux", "{0}", "{1}", "{2}"]', inputs.device, inputs.driver_version, inputs.runner_version)) }}
     defaults:
@@ -295,7 +297,7 @@ jobs:
           ${{ env.TRITON_TEST_CMD }} --minicore

       - name: Run mxfp tests
-        if: matrix.suite == 'mxfp'
+        if: matrix.suite == 'rest'
         run: |
           ${{ env.TRITON_TEST_CMD }} --mxfp

@@ -309,15 +311,7 @@
         run: |
           ${{ env.TRITON_TEST_CMD }} --interpreter

-      # FIXME: make sure new tutorials are added to one of the groups (mxfp, scaled_dot, rest)
-
-      - name: Select tutorials to run (mxfp)
-        if: matrix.suite == 'mxfp'
-        run: |
-          cat <<EOF | tee tutorials.txt
-          06-fused-attention
-          EOF
-
+      # FIXME: make sure new tutorials are added to one of the groups (scaled_dot, rest, tutorial-faX)
       - name: Select tutorials to run (scaled_dot)
         if: matrix.suite == 'scaled_dot'
         run: |
@@ -341,10 +335,16 @@
           EOF

       - name: Run Tutorials
-        if: matrix.suite == 'mxfp' || matrix.suite == 'scaled_dot' || matrix.suite == 'rest'
+        if: matrix.suite == 'scaled_dot' || matrix.suite == 'rest'
         run: |
           ${{ env.TRITON_TEST_CMD }} --select-from-file tutorials.txt --tutorial

+      # Run 06-fused-attention.py separately, because it is split into 3 configs
+      - name: Run Flash Attention tutorials
+        if: matrix.suite == 'tutorial-fa-64' || matrix.suite == 'tutorial-fa-128-fwdfp8' || matrix.suite == 'tutorial-fa-128-nofwdfp8'
+        run: |
+          ${{ env.TRITON_TEST_CMD }} "--${{ matrix.suite }}"
+
       - name: Install transformers
         if: matrix.suite == 'rest'
         run: |

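A note on the new step above: because the three suite names match the new test-triton.sh flags, the quoted expression "--${{ matrix.suite }}" expands to a single CLI argument. A small sketch of that expansion, using a plain shell variable to stand in for the matrix value:

    # With matrix.suite == "tutorial-fa-128-fwdfp8" the step effectively runs:
    #   $TRITON_TEST_CMD --tutorial-fa-128-fwdfp8
    suite="tutorial-fa-128-fwdfp8"
    printf '%s\n' "--${suite}"    # prints --tutorial-fa-128-fwdfp8
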
python/tutorials/06-fused-attention.py

Lines changed: 53 additions & 4 deletions
@@ -24,6 +24,47 @@
 DEVICE = triton.runtime.driver.active.get_active_torch_device()


+def parse_config():
+    head_dim_txt = os.getenv("HEAD_DIM", "")
+    print("HEAD_DIM", head_dim_txt)
+
+    head_dims = [64, 128]
+    try:
+        head_dim = int(head_dim_txt)
+        head_dims = [head_dim]
+    except ValueError:
+        pass
+
+    # With FWD_FP8_ONLY we will only run forward with FP8 and will not run backward
+    # With FWD_FP8_SKIP we will skip forward with FP8, but will run forward with FP16 and backward with FP8 & FP16
+    # The reason is that currently the slowest step is kernel autotuning, which is only done for forward pass
+    # The slowest autotune is FP8, which is several times slower than FP16
+    # However, backward pass currently involves calling forward pass, hence, has the same slow time
+    # But, backward pass for FP8 is actually just backward pass for FP16, there is no difference, so it uses FP16 forward tuning.
+    # So from a workload perspective the best strategy for parallel execution is to run separately
+    # 1. Forward pass with FP8, which will trigger autotune(FP8-FWD)
+    # 2. Forward pass with FP16 and backward with FP8 & FP16, which will trigger only autotune(FP16-FWD)
+    fwd_fp8_only_txt = os.getenv("FWD_FP8_ONLY", "0")
+    fwd_fp8_skip_txt = os.getenv("FWD_FP8_SKIP", "0")
+    print("FWD_FP8_ONLY", fwd_fp8_only_txt)
+    print("FWD_FP8_SKIP", fwd_fp8_skip_txt)
+    if fwd_fp8_only_txt == "1":
+        fwd_dtypes = ['fp8']
+        modes = ['fwd']
+    elif fwd_fp8_skip_txt == "1":
+        fwd_dtypes = ['fp16']
+        modes = ['fwd', 'bwd']
+    else:
+        fwd_dtypes = ['fp8', 'fp16']
+        modes = ['fwd', 'bwd']
+
+    return head_dims, fwd_dtypes, modes
+
+
+HEAD_DIMS, FWD_DTYPES, MODES = parse_config()
+print("HEAD_DIM_OPTIONS", HEAD_DIMS, "FWD_DTYPES", FWD_DTYPES, "MODES", MODES)
+
+
 def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"

@@ -705,15 +746,23 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtyp
     # Enable warpspec for causal fwd on Hopper
     enable_ws = mode == "fwd" and (is_blackwell() or (is_hopper() and not causal))
     for warp_specialize in [False, True] if enable_ws else [False]:
+
+        if HEAD_DIM not in HEAD_DIMS or mode not in MODES:
+            continue
+        include_fp8 = mode != 'fwd' or 'fp8' in FWD_DTYPES
+        include_fp16 = mode != 'fwd' or 'fp16' in FWD_DTYPES
+
         configs.append(
             triton.testing.Benchmark(
                 x_names=["N_CTX"],
                 x_vals=[2**i for i in range(10, 15)],
                 line_arg="provider",
-                line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) +
-                (["flash"] if HAS_FLASH else []),
-                line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) +
-                (["Flash-2"] if HAS_FLASH else []),
+                line_vals=((["triton-fp16"] if include_fp16 else []) +
+                           (["triton-fp8"] if TORCH_HAS_FP8 and include_fp8 else []) +
+                           (["flash"] if HAS_FLASH else [])),
+                line_names=((["Triton [FP16]"] if include_fp16 else []) +
+                            (["Triton [FP8]"] if TORCH_HAS_FP8 and include_fp8 else []) +
+                            (["Flash-2"] if HAS_FLASH else [])),
                 styles=[("red", "-"), ("blue", "-"), ("green", "-")],
                 ylabel="TFLOPS",
                 plot_name=

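Since parse_config() only reads environment variables, the tutorial can also be driven directly with the same settings the CI parts use. A hedged sketch, assuming it is run from python/tutorials with the tutorial's usual dependencies installed:

    # Part 1: head dim 64, both dtypes, fwd and bwd
    HEAD_DIM=64 python 06-fused-attention.py
    # Part 2: head dim 128, forward pass with FP8 only (triggers the slow FP8 forward autotune)
    HEAD_DIM=128 FWD_FP8_ONLY=1 python 06-fused-attention.py
    # Part 3: head dim 128, FP16 forward plus FP8/FP16 backward (only the FP16 forward autotune)
    HEAD_DIM=128 FWD_FP8_SKIP=1 python 06-fused-attention.py
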
scripts/pytest-utils.sh

Lines changed: 0 additions & 4 deletions
@@ -58,10 +58,6 @@ pytest() {
 }

 run_tutorial_test() {
-    if [[ -f $TRITON_TEST_SELECTFILE ]] && ! grep -qF "$1" "$TRITON_TEST_SELECTFILE"; then
-        return
-    fi
-
     echo
     echo "****** Running $1 test ******"
     echo

scripts/test-triton.sh

Lines changed: 55 additions & 11 deletions
@@ -18,6 +18,9 @@ TEST:
     --softmax
     --gemm
     --flash-attention
+    - tutorial-fa-64
+    - tutorial-fa-128-fwdfp8
+    - tutorial-fa-128-nofwdfp8
     --flex-attention
     --instrumentation
     --inductor
@@ -110,6 +113,27 @@ while (( $# != 0 )); do
            TEST_DEFAULT=false
            shift
            ;;
+        --tutorial-fa-64)
+            TEST_TUTORIAL=true
+            TEST_TUTORIAL_FA=true
+            FA_CONFIG="HEAD_DIM=64"
+            TEST_DEFAULT=false
+            shift
+            ;;
+        --tutorial-fa-128-fwdfp8)
+            TEST_TUTORIAL=true
+            TEST_TUTORIAL_FA=true
+            FA_CONFIG="HEAD_DIM=128 FWD_FP8_ONLY=1"
+            TEST_DEFAULT=false
+            shift
+            ;;
+        --tutorial-fa-128-nofwdfp8)
+            TEST_TUTORIAL=true
+            TEST_TUTORIAL_FA=true
+            FA_CONFIG="HEAD_DIM=128 FWD_FP8_SKIP=1"
+            TEST_DEFAULT=false
+            shift
+            ;;
        --microbench)
            TEST_MICRO_BENCHMARKS=true
            TEST_DEFAULT=false
@@ -371,17 +395,37 @@ run_tutorial_tests() {
    python -m pip install matplotlib pandas tabulate -q
    cd $TRITON_PROJ/python/tutorials

-    run_tutorial_test "01-vector-add"
-    run_tutorial_test "02-fused-softmax"
-    run_tutorial_test "03-matrix-multiplication"
-    run_tutorial_test "04-low-memory-dropout"
-    run_tutorial_test "05-layer-norm"
-    run_tutorial_test "06-fused-attention"
-    run_tutorial_test "07-extern-functions"
-    run_tutorial_test "08-grouped-gemm"
-    run_tutorial_test "09-persistent-matmul"
-    run_tutorial_test "10-experimental-block-pointer"
-    run_tutorial_test "10i-experimental-block-pointer"
+    tutorials=(
+        "01-vector-add"
+        "02-fused-softmax"
+        "03-matrix-multiplication"
+        "04-low-memory-dropout"
+        "05-layer-norm"
+        "06-fused-attention"
+        "07-extern-functions"
+        "08-grouped-gemm"
+        "09-persistent-matmul"
+        "10-experimental-block-pointer"
+        "10i-experimental-block-pointer"
+    )
+    if [ "${TEST_TUTORIAL_FA:-false}" = true ]; then
+        tutorials=(
+            "06-fused-attention"
+        )
+
+        if [ -n "${FA_CONFIG:-}" ]; then
+            # Contains the specific config for the fused attention tutorial
+            export $FA_CONFIG
+        fi
+    fi
+
+    for tutorial in "${tutorials[@]}"; do
+        if [[ -f $TRITON_TEST_SELECTFILE ]] && ! grep -qF "$tutorial" "$TRITON_TEST_SELECTFILE"; then
+            continue
+        fi
+
+        run_tutorial_test "$tutorial"
+    done
 }

 run_microbench_tests() {

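One detail in run_tutorial_tests above: FA_CONFIG holds space-separated NAME=VALUE pairs, and the unquoted `export $FA_CONFIG` relies on shell word splitting so that each pair is exported separately. A standalone sketch of that behavior (default IFS assumed):

    FA_CONFIG="HEAD_DIM=128 FWD_FP8_ONLY=1"
    # Unquoted expansion splits on spaces, so export receives two NAME=VALUE words.
    export $FA_CONFIG
    echo "$HEAD_DIM"        # 128
    echo "$FWD_FP8_ONLY"    # 1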