
Commit eb99fa1

[release/2.7] Fix for flex attention tuning (#2589)
Bug fix after #2392 landed. The issue was caused by a bad merge-conflict resolution that left the code calling an outdated API:

> torch._inductor.exc.LoweringException: NameError: name '_get_default_config_bwd' is not defined
> target: flex_attention_backward

Models now run to completion.
1 parent 59925f5

File tree: 1 file changed (+0 −14 lines)


torch/_inductor/kernel/flex_attention.py

Lines changed: 0 additions & 14 deletions
@@ -2476,20 +2476,6 @@ def flex_attention_backward(*args, **kwargs):
     SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
 
     choices: list[Any] = []
-    configs: list[tuple[int, int, int, int]] = []
-    configs.append(_get_default_config_bwd(query))
-    if config.max_autotune:
-        num_stages_list = [1, 3, 4, 5] if torch.version.hip is None else [1]
-        configs.extend(
-            [
-                (BLOCK1, BLOCK2, w, s)
-                for BLOCK1 in [32, 64]
-                for BLOCK2 in [32, 64, 128]
-                for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4])
-                for s in num_stages_list
-                if BLOCK2 % BLOCK1 == 0
-            ]
-        )
 
     dtype = query.get_dtype()
     head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])
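For reference, the deleted lines enumerated backward-kernel tuning configs with a filtered list comprehension. A minimal standalone sketch of that enumeration, with `config.max_autotune` and `torch.version.hip` replaced by plain booleans and `_get_default_config_bwd(query)` stubbed as a fixed tuple (both substitutions are assumptions for illustration, not the inductor API):

```python
def enumerate_bwd_configs(max_autotune: bool, is_hip: bool) -> list[tuple[int, int, int, int]]:
    """Sketch of the removed config enumeration from flex_attention_backward.

    Each config is (BLOCK1, BLOCK2, num_warps, num_stages). The real code
    seeded the list with _get_default_config_bwd(query); here a fixed tuple
    stands in for that default.
    """
    configs: list[tuple[int, int, int, int]] = []
    configs.append((32, 32, 4, 1))  # stub for _get_default_config_bwd(query)
    if max_autotune:
        # ROCm (HIP) builds were restricted to a single pipeline stage.
        num_stages_list = [1, 3, 4, 5] if not is_hip else [1]
        configs.extend(
            [
                (BLOCK1, BLOCK2, w, s)
                for BLOCK1 in [32, 64]
                for BLOCK2 in [32, 64, 128]
                # larger tiles get the 8-warp variant as well
                for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4])
                for s in num_stages_list
                # BLOCK2 must tile evenly by BLOCK1
                if BLOCK2 % BLOCK1 == 0
            ]
        )
    return configs
```

Without `max_autotune` only the default config survives; with it, the comprehension yields 7 (block, warp) combinations per stage count, so the divisibility filter keeps the search space small.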
