
Commit 8255e76

[FlexAttn] Set USE_TMA explicitly (#5075)
This PR removes the `use_tma.patch` patch, as there is no plan to upstream it to PyTorch; instead, `USE_TMA` is set explicitly through `kernel_options` in the FlexAttention benchmark.

Signed-off-by: Whitney Tsang <[email protected]>

1 parent 42793a2 · commit 8255e76

File tree: 3 files changed, +1 -38 lines


benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
                          device=DEVICE)
 
     elif provider == 'triton':
-        kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True}
+        kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True, 'USE_TMA': True}
         triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
             not H_q == H_kv), kernel_options=kernel_options)
         if MODE == 'bwd':
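For reference, below is a minimal sketch of the call pattern after this change: `USE_TMA` is requested per call through `kernel_options` rather than being hard-wired into PyTorch by the removed patch. The device, tensor shapes, dtype, and the causal mask_mod are illustrative assumptions, not values from the benchmark, and the option only takes effect on a PyTorch build whose FlexAttention Triton template recognizes `USE_TMA`.

# Minimal sketch, assuming a PyTorch build whose FlexAttention Triton
# template recognizes the USE_TMA kernel option; shapes/dtype are illustrative.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

DEVICE = 'xpu'                    # assumption: Intel GPU backend; use 'cuda' elsewhere
B, H, SEQ, D = 4, 16, 1024, 64    # illustrative sizes, not the benchmark's

q = torch.randn(B, H, SEQ, D, device=DEVICE, dtype=torch.float16)
k = torch.randn(B, H, SEQ, D, device=DEVICE, dtype=torch.float16)
v = torch.randn(B, H, SEQ, D, device=DEVICE, dtype=torch.float16)

def causal(b, h, q_idx, kv_idx):  # simple causal mask_mod for the example
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B, H, SEQ, SEQ, device=DEVICE)
compiled_flex_attention = torch.compile(flex_attention)

# USE_TMA is now passed explicitly per call instead of being forced on
# by the deleted use_tma.patch.
kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True, 'USE_TMA': True}
out = compiled_flex_attention(q, k, v, block_mask=block_mask,
                              kernel_options=kernel_options)

Keeping the flag in `kernel_options` means the benchmark controls TMA usage itself, so no out-of-tree patch to PyTorch is required.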

scripts/patch-pytorch.sh

Lines changed: 0 additions & 1 deletion
@@ -35,4 +35,3 @@ apply_patch() {
 echo "Applying PyTorch patches in $REPO_ROOT"
 
 # put your patch applies here
-apply_patch ./patch/use_tma.patch

scripts/patch/use_tma.patch

Lines changed: 0 additions & 36 deletions
This file was deleted.
