
Commit d1c90a0
[SWDEV-538312] Fix num_stages=8 configs for flex attention (#2277)
num_stages==8 configs are always skipped, causing breakages. Example error:

```
torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.
  target: flex_attention_backward
```
1 parent eb37e58 commit d1c90a0

File tree

1 file changed: +0 −3 lines


torch/_inductor/kernel/flex_attention.py

```diff
@@ -2273,9 +2273,6 @@ def flex_attention_backward(*args, **kwargs):
                 or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
             ):
                 continue
-            if num_warps == 8:
-                # Working around https://github.com/pytorch/pytorch/issues/141603
-                continue

             # Performance tuning
             cur_kernel_options = original_kernel_options.copy()
```
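The failure mode fixed here can be illustrated with a small hypothetical sketch (this is not the actual inductor code; `build_choices` and its parameters are invented for illustration). Autotuning builds a list of candidate kernel configs, filtering out invalid ones; if a stale unconditional guard also skips every remaining candidate, the choice list ends up empty and lowering fails with a `NoValidChoicesError`-style "No choices to select" error:

```python
def build_choices(configs, skip_num_warps_8=True):
    """Collect surviving autotune candidates from (num_warps, num_stages) pairs.

    Hypothetical sketch of the bug: when every otherwise-valid config on a
    backend uses num_warps == 8, the stale workaround guard filters out all
    of them, leaving no choices for the autotuner to select from.
    """
    choices = []
    for cfg in configs:
        num_warps, num_stages = cfg
        if num_stages not in (1, 3, 4, 8):  # basic validity filter (invented)
            continue
        if skip_num_warps_8 and num_warps == 8:
            # Stale workaround guard: skips every candidate on this backend.
            continue
        choices.append(cfg)
    return choices

configs = [(8, 1), (8, 3), (8, 8)]
print(build_choices(configs, skip_num_warps_8=True))   # [] -> "no choices" failure
print(build_choices(configs, skip_num_warps_8=False))  # all three configs survive
```

Removing the guard, as this commit does in `flex_attention_backward`, restores the configs to the candidate pool so at least one choice is always available.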
