2 files changed, +33 -0 lines changed

@@ -38,3 +38,4 @@ echo "Applying PyTorch patches in $REPO_ROOT"
 apply_patch ./patch/flex_attn_143553.patch
 apply_patch pytorch_fp64.patch
 apply_patch ./patch/pytorch_global_scratch.patch
+apply_patch ./patch/flex_decoding.patch

./patch/flex_decoding.patch (new file):

Subject: [PATCH] Remove the minimum-size constraint on BLOCK_M in flex_decoding.py
---
Index: torch/_inductor/kernel/flex_decoding.py
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
--- a/torch/_inductor/kernel/flex_decoding.py (revision 5329b5b5623af429a64cc7e679b1fa03f47225d8)
+++ b/torch/_inductor/kernel/flex_decoding.py (revision beef69e50627af4d6009bcb9c9f758fa9f4aa81c)
@@ -457,15 +457,12 @@
     kernel_options.setdefault(
         "BLOCK_M",
         (
-            max(
-                next_power_of_2(
-                    V.graph.sizevars.size_hint(
-                        seq_len_q,
-                        fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
-                    )
-                    * gqa_shared_heads
-                ),
-                8,
+            next_power_of_2(
+                V.graph.sizevars.size_hint(
+                    seq_len_q,
+                    fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
+                )
+                * gqa_shared_heads
             )
         ),
     )
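
To make the effect of the patch concrete, here is a minimal standalone Python sketch, not the Inductor source: next_power_of_2 is reimplemented locally, block_m_before/block_m_after are hypothetical helper names, and the sample sizes are made up. Before the patch the default BLOCK_M was clamped to at least 8; after it, the power-of-two of seq_len_q * gqa_shared_heads is used directly.

# Sketch of the default BLOCK_M computation before and after the patch.
# Illustrative only; the real code in torch/_inductor/kernel/flex_decoding.py
# reads seq_len_q as a size hint from the Inductor graph.

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, for positive n.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def block_m_before(seq_len_q: int, gqa_shared_heads: int) -> int:
    # Old default: clamp to a minimum of 8.
    return max(next_power_of_2(seq_len_q * gqa_shared_heads), 8)

def block_m_after(seq_len_q: int, gqa_shared_heads: int) -> int:
    # Patched default: use the power-of-two size directly.
    return next_power_of_2(seq_len_q * gqa_shared_heads)

if __name__ == "__main__":
    for q, h in [(1, 1), (1, 2), (1, 4), (1, 8)]:
        print(f"seq_len_q={q} gqa_shared_heads={h}: "
              f"before={block_m_before(q, h)} after={block_m_after(q, h)}")

The two defaults differ only when seq_len_q * gqa_shared_heads is below 8, i.e. short decode queries with few shared GQA heads, which is exactly the case the removed max(..., 8) floor used to pad up.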