File tree: 2 files changed, +33 −0 lines changed
lines changed Original file line number Diff line number Diff line change @@ -38,3 +38,4 @@ echo "Applying PyTorch patches in $REPO_ROOT"
38
38
apply_patch ./patch/flex_attn_143553.patch
39
39
apply_patch pytorch_fp64.patch
40
40
apply_patch ./patch/pytorch_global_scratch.patch
41
+ apply_patch ./patch/flex_decoding.patch
Original file line number Diff line number Diff line change
1
+ Subject: [PATCH] Remove the minimum-number constraint on block M in flex_decoding.py
2
+ ---
3
+ Index: torch/_inductor/kernel/flex_decoding.py
4
+ IDEA additional info:
5
+ Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
6
+ <+>UTF-8
7
+ ===================================================================
8
+ diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
9
+ --- a/torch/_inductor/kernel/flex_decoding.py (revision 5329b5b5623af429a64cc7e679b1fa03f47225d8)
10
+ +++ b/torch/_inductor/kernel/flex_decoding.py (revision beef69e50627af4d6009bcb9c9f758fa9f4aa81c)
11
+ @@ -457,15 +457,12 @@
12
+ kernel_options.setdefault(
13
+ "BLOCK_M",
14
+ (
15
+ - max(
16
+ - next_power_of_2(
17
+ - V.graph.sizevars.size_hint(
18
+ - seq_len_q,
19
+ - fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
20
+ - )
21
+ - * gqa_shared_heads
22
+ - ),
23
+ - 8,
24
+ + next_power_of_2(
25
+ + V.graph.sizevars.size_hint(
26
+ + seq_len_q,
27
+ + fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
28
+ + )
29
+ + * gqa_shared_heads
30
+ )
31
+ ),
32
+ )
You can’t perform that action at this time.
0 commit comments