Commit 9153bd0

Add flex decoding patch. (#4766)
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 7cf3c0b commit 9153bd0

2 files changed: 33 additions and 0 deletions

scripts/patch-pytorch.sh

Lines changed: 1 addition & 0 deletions
@@ -38,3 +38,4 @@ echo "Applying PyTorch patches in $REPO_ROOT"
 apply_patch ./patch/flex_attn_143553.patch
 apply_patch pytorch_fp64.patch
 apply_patch ./patch/pytorch_global_scratch.patch
+apply_patch ./patch/flex_decoding.patch

scripts/patch/flex_decoding.patch

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+Subject: [PATCH] Remove the min number constrain on block M in flex_decoding.py
+---
+Index: torch/_inductor/kernel/flex_decoding.py
+IDEA additional info:
+Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
+<+>UTF-8
+===================================================================
+diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
+--- a/torch/_inductor/kernel/flex_decoding.py (revision 5329b5b5623af429a64cc7e679b1fa03f47225d8)
++++ b/torch/_inductor/kernel/flex_decoding.py (revision beef69e50627af4d6009bcb9c9f758fa9f4aa81c)
+@@ -457,15 +457,12 @@
+     kernel_options.setdefault(
+         "BLOCK_M",
+         (
+-            max(
+-                next_power_of_2(
+-                    V.graph.sizevars.size_hint(
+-                        seq_len_q,
+-                        fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
+-                    )
+-                    * gqa_shared_heads
+-                ),
+-                8,
+-            )
++            next_power_of_2(
++                V.graph.sizevars.size_hint(
++                    seq_len_q,
++                    fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
++                )
++                * gqa_shared_heads
+             )
+         ),
+     )
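
The effect of the patch on the default BLOCK_M can be sketched in plain Python. This is an illustration only, not the Inductor code: `next_power_of_2` below is a local stand-in for the helper used in flex_decoding.py, `block_m_before` and `block_m_after` are hypothetical names, and the size hint for `seq_len_q` is assumed to be already resolved to a positive integer.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n. Local stand-in for the Inductor helper (assumption)."""
    if n < 1:
        raise ValueError("expected a positive integer")
    return 1 << (n - 1).bit_length()


def block_m_before(seq_len_q: int, gqa_shared_heads: int) -> int:
    # Pre-patch default: rounded-up row count, clamped to at least 8.
    return max(next_power_of_2(seq_len_q * gqa_shared_heads), 8)


def block_m_after(seq_len_q: int, gqa_shared_heads: int) -> int:
    # Post-patch default: the same rounded-up value without the lower bound.
    return next_power_of_2(seq_len_q * gqa_shared_heads)


if __name__ == "__main__":
    for seq_len_q, heads in [(1, 1), (1, 4), (3, 4)]:
        before = block_m_before(seq_len_q, heads)
        after = block_m_after(seq_len_q, heads)
        print(f"seq_len_q={seq_len_q} gqa_shared_heads={heads}: {before} -> {after}")
```

Under these assumptions the two defaults differ only when seq_len_q * gqa_shared_heads rounds up to less than 8. Because the value is set through `kernel_options.setdefault`, a caller-supplied "BLOCK_M" still takes precedence in either version.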
