Commit e631d76

drisspg authored and pytorchmergebot committed
[Flex] Changing how bwd configs are set up and updating default b200 config (pytorch#163318)
```Shell
Up to 4x perf boost

🔝 Top 5 Performance Differences (by absolute %), shape: (5, 7)

attn_type  dtype           shape(B,Hq,M,Hkv,N,D)           TFlops BWD (base)  TFlops BWD (better_configs)  speedup_over_base  pct_delta
noop       torch.bfloat16  (4, 16, 32768, 16, 32768, 128)  124.775035         532.580435                   4.268325           326.832527
noop       torch.bfloat16  (4, 16, 16384, 16, 16384, 128)  124.494557         519.798488                   4.175271           317.527078
causal     torch.bfloat16  (4, 16, 32768, 16, 32768, 128)  123.984189         512.877391                   4.136635           313.663544
noop       torch.bfloat16  (4, 16, 8192, 16, 8192, 128)    122.827725         496.195958                   4.039772           303.977164
causal     torch.bfloat16  (4, 16, 16384, 16, 16384, 128)  123.826738         484.244647                   3.910663           291.066303

🔺 Top 5 Cases Where better_configs (change) is Faster than base (baseline), shape: (5, 7)

(identical to the five cases above)

🔻 Top 5 Cases Where better_configs (change) is Slower than base (baseline), shape: (5, 7)

attn_type      dtype           shape(B,Hq,M,Hkv,N,D)          TFlops BWD (base)  TFlops BWD (better_configs)  speedup_over_base  pct_delta
document_mask  torch.bfloat16  (4, 16, 8192, 16, 8192, 128)   267.502004         250.728732                   0.937297           -6.270335
document_mask  torch.bfloat16  (4, 16, 8192, 4, 8192, 128)    248.510516         235.210874                   0.946483           -5.351742
document_mask  torch.bfloat16  (4, 16, 16384, 4, 16384, 128)  282.856295         271.806926                   0.960936           -3.906354
document_mask  torch.bfloat16  (4, 16, 8192, 16, 8192, 64)    282.212695         280.519092                   0.993999           -0.600116
document_mask  torch.bfloat16  (4, 16, 32768, 4, 32768, 128)  295.864073         294.477894                   0.995315           -0.468519

📊 Performance Summary:
============================================================
Baseline: base
Change:   better_configs
Geometric Mean Speedup (change over baseline): 1.9954x
Geometric Mean % Change:                       +99.54%
Median Speedup (change over baseline):         2.1590x
Speedup Std Dev:                               0.9800
Valid Comparisons:                             60/60
```

Pull Request resolved: pytorch#163318
Approved by: https://github.com/BoyuanFeng
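For reference, the summary statistics follow mechanically from the per-case numbers in the tables: each case's speedup is the better_configs TFlops divided by the base TFlops, pct_delta is that speedup expressed as a percentage change, and the headline 1.9954x / +99.54% figures are the geometric mean over all 60 per-case speedups. A minimal sketch of that arithmetic (variable names here are illustrative, not part of the benchmark script):

```python
import math

# (TFlops BWD base, TFlops BWD better_configs) pairs, copied from the first table above.
pairs = [
    (124.775035, 532.580435),
    (124.494557, 519.798488),
    (123.984189, 512.877391),
]

speedups = [new / old for old, new in pairs]        # ~4.2683, 4.1753, 4.1366
pct_deltas = [(s - 1.0) * 100.0 for s in speedups]  # ~326.83, 317.53, 313.66

# Geometric mean speedup: n-th root of the product of the per-case speedups.
geo_mean = math.exp(sum(math.log(s) for s in speedups) / len(speedups))
print(f"geometric mean speedup: {geo_mean:.4f}x  ({(geo_mean - 1) * 100:+.2f}%)")
```

Applied to all 60 comparisons, this is how the reported 1.9954x geometric mean corresponds to the +99.54% average change.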
1 parent f8f230a commit e631d76

File tree: 2 files changed, +169 −76 lines changed

torch/_inductor/kernel/flex/flex_attention.py

Lines changed: 18 additions & 11 deletions
@@ -1,11 +1,13 @@
 # mypy: allow-untyped-defs
 """Triton Implementation of the flex_attention Kernel"""
 
+from __future__ import annotations
+
 import logging
 import math
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import sympy
 
@@ -35,6 +37,10 @@
 from .flex_decoding import _use_flex_decoding, create_flex_decoding_kernel
 
 
+if TYPE_CHECKING:
+    from ...template_heuristics.triton import FlexBwDConfig, FlexConfig
+
+
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
 Expr = sympy.Expr
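The new `if TYPE_CHECKING:` block, combined with the `from __future__ import annotations` added above, makes `FlexConfig` and `FlexBwDConfig` available to type checkers without importing `template_heuristics.triton` at runtime, avoiding import-time cost and potential circular imports. A minimal, self-contained sketch of the pattern (the module and class names below are placeholders, not the real ones):

```python
from __future__ import annotations  # all annotations become strings, evaluated lazily

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never executed at runtime,
    # so this module does not even need to be importable here.
    from my_heuristics_module import MyConfig  # placeholder names


def pick_configs() -> list[MyConfig]:
    # The return annotation is never evaluated at runtime, so no import is needed.
    return []
```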
@@ -279,7 +285,7 @@ def flex_attention(
 
     dtype = query.get_dtype()
     head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
-    configs = V.choices.get_flex_attention_fwd_configs(
+    configs: list[FlexConfig] = V.choices.get_flex_attention_fwd_configs(
         head_dim, dtype, query.get_device().type
     )
 
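The `list[FlexConfig]` annotation documents what the forward tuning entries are. As a rough mental model only (this is an assumption about the shape of such a config, not the actual `FlexConfig` definition in `template_heuristics/triton.py`), a forward config bundles one tile shape with Triton launch parameters:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FwdTileConfig:  # illustrative stand-in for FlexConfig, not the real class
    block_m: int      # rows of the query tile handled per program
    block_n: int      # columns of the key/value tile iterated per step
    num_warps: int    # Triton launch parameter
    num_stages: int   # software-pipelining depth


# Hypothetical defaults; the real values come from per-device heuristics.
example_fwd_configs = [
    FwdTileConfig(block_m=128, block_n=64, num_warps=4, num_stages=3),
    FwdTileConfig(block_m=64, block_n=64, num_warps=4, num_stages=2),
]
```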
@@ -719,20 +725,21 @@ def flex_attention_backward(*args, **kwargs):
 
     dtype = query.get_dtype()
     head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
-    configs = V.choices.get_flex_attention_bwd_configs(
+    configs: list[FlexBwDConfig] = V.choices.get_flex_attention_bwd_configs(
         head_dim, dtype, query.get_device().type
     )
 
     # Default config for warp specialization
     num_consumer_groups, num_buffers_warp_spec = 0, 0
 
     original_kernel_options = kernel_options.copy()
+
     for conf in configs:
         if (
-            SPARSE_KV_BLOCK_SIZE % conf.block_m != 0
-            or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0
-            or SPARSE_KV_BLOCK_SIZE % conf.block_n != 0
-            or SPARSE_Q_BLOCK_SIZE % conf.block_n != 0
+            SPARSE_KV_BLOCK_SIZE % conf.block_n1 != 0
+            or SPARSE_Q_BLOCK_SIZE % conf.block_m1 != 0
+            or SPARSE_KV_BLOCK_SIZE % conf.block_n2 != 0
+            or SPARSE_Q_BLOCK_SIZE % conf.block_m2 != 0
         ):
             continue
 
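This hunk is the core of the change in how backward configs are set up: a backward config now carries two independent tile pairs, (`block_m1`, `block_n1`) and (`block_m2`, `block_n2`), one per backward sub-pass, and a config is skipped unless every tile size evenly divides the corresponding block-sparse mask granularity. A standalone sketch of that filter, using an illustrative config class in place of the real `FlexBwDConfig`:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BwdTileConfig:  # illustrative stand-in for FlexBwDConfig
    block_m1: int
    block_n1: int
    block_m2: int
    block_n2: int


def divides_sparse_blocks(conf: BwdTileConfig, sparse_q: int, sparse_kv: int) -> bool:
    # Mirrors the divisibility check above: reject configs whose tiles do not
    # evenly divide SPARSE_Q_BLOCK_SIZE / SPARSE_KV_BLOCK_SIZE.
    return (
        sparse_kv % conf.block_n1 == 0
        and sparse_q % conf.block_m1 == 0
        and sparse_kv % conf.block_n2 == 0
        and sparse_q % conf.block_m2 == 0
    )


# A config whose tiles divide 128-sized sparse blocks is kept; one with a 96-wide tile is not.
print(divides_sparse_blocks(BwdTileConfig(64, 128, 128, 64), sparse_q=128, sparse_kv=128))  # True
print(divides_sparse_blocks(BwdTileConfig(96, 128, 128, 64), sparse_q=128, sparse_kv=128))  # False
```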
@@ -755,10 +762,10 @@ def flex_attention_backward(*args, **kwargs):
             "num_buffers_warp_spec", num_buffers_warp_spec
         )
 
-        cur_kernel_options.setdefault("BLOCK_M1", conf.block_m)
-        cur_kernel_options.setdefault("BLOCK_N1", conf.block_n)
-        cur_kernel_options.setdefault("BLOCK_M2", conf.block_n)
-        cur_kernel_options.setdefault("BLOCK_N2", conf.block_m)
+        cur_kernel_options.setdefault("BLOCK_M1", conf.block_m1)
+        cur_kernel_options.setdefault("BLOCK_N1", conf.block_n1)
+        cur_kernel_options.setdefault("BLOCK_M2", conf.block_m2)
+        cur_kernel_options.setdefault("BLOCK_N2", conf.block_n2)
 
         # Blocksparse options
         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
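The net effect of this last hunk: previously a single (`block_m`, `block_n`) pair was reused for both backward tile shapes, with M and N swapped for the second pass; now each of the four `BLOCK_*` kernel options is taken from its own field on the config, so the two backward sub-passes can be tuned independently, which the commit pairs with updated default B200 configs. A small before/after sketch of just that mapping, with made-up values:

```python
from types import SimpleNamespace


def old_mapping(opts: dict, block_m: int, block_n: int) -> None:
    # Old behavior: a single tile shape; the second pass is its transpose.
    opts.setdefault("BLOCK_M1", block_m)
    opts.setdefault("BLOCK_N1", block_n)
    opts.setdefault("BLOCK_M2", block_n)
    opts.setdefault("BLOCK_N2", block_m)


def new_mapping(opts: dict, conf) -> None:
    # New behavior: four independent tile sizes straight from the backward config.
    opts.setdefault("BLOCK_M1", conf.block_m1)
    opts.setdefault("BLOCK_N1", conf.block_n1)
    opts.setdefault("BLOCK_M2", conf.block_m2)
    opts.setdefault("BLOCK_N2", conf.block_n2)


opts: dict = {}
conf = SimpleNamespace(block_m1=64, block_n1=128, block_m2=128, block_n2=64)  # made-up values
new_mapping(opts, conf)
print(opts)  # {'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64}
```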
