Commit a8f38c3

Flex Attention + Packing with BlockMask support (axolotl-ai-cloud#2363)
1 parent e7e0cd9 commit a8f38c3
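
The title refers to packing with a BlockMask: with sample packing, several short documents share one training sequence, and flex attention expresses the "causal, but only within your own document" rule as a BlockMask rather than a dense attention mask. The sketch below is illustrative only (none of it is code from this commit) and assumes torch >= 2.6 and a CUDA device; the commit itself does not build the mask by hand — as the diffs below show, it routes packed batches to transformers' flex-attention integration and patches its compile path.

```python
# Illustrative only: packed documents expressed as a flex-attention BlockMask.
# Assumes torch >= 2.6 and a CUDA device; shapes and document lengths are made up.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, SEQ_LEN, HEAD_DIM = 1, 4, 128, 32

# Three documents of lengths 48, 32 and 48 packed into one 128-token sequence.
document_ids = torch.repeat_interleave(
    torch.arange(3, device="cuda"), torch.tensor([48, 32, 48], device="cuda")
)


def doc_causal_mask(b, h, q_idx, kv_idx):
    # A token may attend only to earlier tokens of its own packed document.
    return (document_ids[q_idx] == document_ids[kv_idx]) & (q_idx >= kv_idx)


# The BlockMask encodes the sparsity pattern once; every layer can reuse it.
block_mask = create_block_mask(doc_causal_mask, B, H, SEQ_LEN, SEQ_LEN, device="cuda")

q, k, v = (torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda") for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)  # (B, H, SEQ_LEN, HEAD_DIM)
```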

File tree

7 files changed (+281, −5 lines):

src/axolotl/core/trainer_builder.py
src/axolotl/monkeypatch/attention/flex_attn.py
src/axolotl/utils/models.py
src/axolotl/utils/schemas/config.py
tests/e2e/multigpu/solo/test_flex.py
tests/e2e/solo/test_flex.py
tests/e2e/utils.py

src/axolotl/core/trainer_builder.py

Lines changed: 5 additions & 1 deletion
@@ -891,7 +891,11 @@ def build_collator(
             if "max_length" in kwargs:
                 kwargs.pop("max_length")
         elif use_batch_sampler_collator:
-            if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
+            if self.cfg.flex_attention:
+                collator = V2BatchSamplerDataCollatorForSeq2Seq
+            elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
+                collator = V2BatchSamplerDataCollatorForSeq2Seq
+            elif (
                 self.cfg.model_config_type in ["llama"]
                 and self.cfg.flash_attention is not True
             ):

src/axolotl/monkeypatch/attention/flex_attn.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+"""Flex attention monkey patch"""
+
+import torch
+import transformers
+
+
+def patch_flex():
+    is_torch_2_6 = torch.__version__.startswith("2.6")
+    is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
+
+    if is_torch_2_6 and is_transformers_below_4_51:
+        from torch.nn.attention.flex_attention import flex_attention
+
+        class WrappedFlexAttention:
+            """
+            We are doing a singleton class so that flex attention is compiled once when it's first called.
+            """
+
+            _instance = None
+            _is_flex_compiled = False
+            _compiled_flex_attention = None
+
+            def __new__(cls, *args, **kwargs):
+                if cls._instance is None:
+                    # Create a new instance if one doesn't already exist
+                    cls._instance = super().__new__(cls)
+                return cls._instance
+
+            @torch.compiler.disable(recursive=False)
+            def __init__(self):
+                """
+                Initialize or update the singleton instance.
+                """
+                if not self._is_flex_compiled:
+                    self._compiled_flex_attention = torch.compile(
+                        flex_attention,
+                        dynamic=False,
+                        mode="max-autotune-no-cudagraphs",
+                        fullgraph=True,
+                    )
+                    self._is_flex_compiled = True
+
+            def __call__(self):
+                return self._compiled_flex_attention
+
+        transformers.integrations.flex_attention.WrappedFlexAttention = (
+            WrappedFlexAttention
+        )
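
A hedged usage sketch (not part of the diff): on torch 2.6.x with transformers < 4.51.0, patch_flex() swaps the singleton above into transformers' flex-attention integration so flex attention is compiled once and reused; on other versions the patch is a no-op and the calls below would hit the stock class instead.

```python
# Illustrative sketch of the compile-once singleton; assumes patch_flex() applied.
import transformers

from axolotl.monkeypatch.attention.flex_attn import patch_flex

patch_flex()  # only patches on torch 2.6.x with transformers < 4.51.0

wrapper_cls = transformers.integrations.flex_attention.WrappedFlexAttention
compiled_first = wrapper_cls()()  # first instantiation runs torch.compile(flex_attention, ...)
compiled_again = wrapper_cls()()  # same singleton instance, same compiled callable
assert compiled_first is compiled_again
```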

src/axolotl/utils/models.py

Lines changed: 15 additions & 3 deletions
@@ -578,7 +578,7 @@ def apply_patches(self) -> None:
 
         if (
             self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
-            and self.cfg.flash_attention
+            and (self.cfg.flash_attention or self.cfg.flex_attention)
            and self.cfg.sample_packing
        ):
            if "auto_map" in self.model_config:

@@ -884,7 +884,16 @@ def set_attention_config(self) -> None:
         """
         sample packing uses custom FA2 patch
         """
-        if self.cfg.flash_attention:
+        if self.cfg.flex_attention:
+            self.model_kwargs["attn_implementation"] = "flex_attention"
+            self.model_config._attn_implementation = (  # pylint: disable=protected-access
+                "flex_attention"
+            )
+            from axolotl.monkeypatch.attention.flex_attn import patch_flex
+
+            patch_flex()
+
+        elif self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
             self.model_kwargs["attn_implementation"] = "flash_attention_2"

@@ -1281,7 +1290,10 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
         should_convert = (
             # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
             # convert them back to fp16/bf16 for flash-attn compatibility.
-            ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
+            (
+                (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
+                and not qlora_fsdp
+            )
             or self.cfg.cut_cross_entropy  # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
         )
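
A hedged sketch (not from the commit) of what the set_attention_config() branch above amounts to: after patch_flex(), the model is loaded through the regular transformers path with attn_implementation="flex_attention". The model name and dtype below are illustrative only.

```python
# Illustrative load path when cfg.flex_attention is enabled; model/dtype are assumptions.
import torch
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.attention.flex_attn import patch_flex

patch_flex()  # install the compile-once wrapper before instantiating the model

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",          # same model the new e2e tests use
    attn_implementation="flex_attention",  # mirrors self.model_kwargs in the hunk above
    torch_dtype=torch.bfloat16,
)
```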

src/axolotl/utils/schemas/config.py

Lines changed: 35 additions & 0 deletions
@@ -223,6 +223,7 @@ class AxolotlInputConfig(
     xformers_attention: bool | None = None
     sdp_attention: bool | None = None
     s2_attention: bool | None = None
+    flex_attention: bool | None = None
     flash_attention: bool | None = None
     flash_attn_cross_entropy: bool | None = None
     flash_attn_rms_norm: bool | None = None

@@ -355,6 +356,22 @@ def datasets_serializer(
             return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
         return None
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_attention_fields(cls, data):
+        fields = (
+            "xformers_attention",
+            "sdp_attention",
+            "s2_attention",
+            "flash_attention",
+            "flex_attention",
+        )
+        non_empty_count = sum(1 for field in fields if data.get(field))
+
+        if non_empty_count > 1:
+            raise ValueError(f"Only one of {', '.join(fields)} must be set")
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_batch_size_fields(cls, data):

@@ -1250,6 +1267,24 @@ def check_adopt_torch_version(cls, data):
             )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_flex_torch_version(cls, data):
+        if (data.get("flex_attention") is not None) and (data.get("flex_attention")):
+            env_capabilities = data.get("env_capabilities", {})
+            torch_version = env_capabilities.get("torch_version")
+
+            if torch_version is None:
+                import torch
+
+                torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+
+            if version.parse(torch_version) < version.parse("2.6.0"):
+                raise ValueError(
+                    "Flex attention is not supported on torch version < 2.6.0"
+                )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_torch_compile_auto(cls, data):
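
For illustration, the two new validators behave like the standalone check below (a paraphrase, not the axolotl API); it only covers the installed-torch path and skips the env_capabilities lookup.

```python
# Standalone paraphrase of check_attention_fields and check_flex_torch_version.
import torch
from packaging import version

ATTENTION_FIELDS = (
    "xformers_attention",
    "sdp_attention",
    "s2_attention",
    "flash_attention",
    "flex_attention",
)


def check_attention_config(data: dict) -> dict:
    # At most one attention backend may be enabled at a time.
    if sum(1 for field in ATTENTION_FIELDS if data.get(field)) > 1:
        raise ValueError(f"Only one of {', '.join(ATTENTION_FIELDS)} must be set")

    # flex_attention additionally requires torch >= 2.6.0.
    if data.get("flex_attention"):
        torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
        if version.parse(torch_version) < version.parse("2.6.0"):
            raise ValueError("Flex attention is not supported on torch version < 2.6.0")
    return data


check_attention_config({"flex_attention": True, "sample_packing": True})
```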

tests/e2e/multigpu/solo/test_flex.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+"""
+E2E tests for multigpu lora tinyllama
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import pytest
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+from huggingface_hub import snapshot_download
+from transformers.testing_utils import get_torch_dist_unique_port
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.utils.dict import DictDefault
+
+from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
+
+LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
+os.environ["WANDB_DISABLED"] = "true"
+
+AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
+
+
+@pytest.fixture(scope="session", autouse=True)
+def download_model():
+    # download the model
+    snapshot_download("HuggingFaceTB/SmolLM2-135M")
+
+
+class TestPackedFlex:
+    """
+    Test case for Packed training of llama models
+    """
+
+    @require_torch_2_6_0
+    def test_loss_llama(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flex_attention": True,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "vicgalle/alpaca-gpt4",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "max_steps": 5,
+                "use_tensorboard": True,
+                "save_strategy": "no",
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
+                "--num-processes",
+                "2",
+                "--main-process-port",
+                f"{get_torch_dist_unique_port()}",
+            ]
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+        )

tests/e2e/solo/test_flex.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+"""
+E2E tests for packed training w/ flex attention
+"""
+
+import logging
+import os
+import unittest
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestPackedFlex(unittest.TestCase):
+    """
+    Test case for Packed training of llama models
+    """
+
+    @require_torch_2_6_0
+    @with_temp_dir
+    def test_loss_llama(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flex_attention": True,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "vicgalle/alpaca-gpt4",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "max_steps": 5,
+                "use_tensorboard": True,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+        )

tests/e2e/utils.py

Lines changed: 13 additions & 1 deletion
@@ -67,9 +67,21 @@ def is_min_2_5_1():
     return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
 
 
+def require_torch_2_6_0(test_case):
+    """
+    Decorator marking a test that requires torch >= 2.6.0
+    """
+
+    def is_min_2_6_0():
+        torch_version = version.parse(torch.__version__)
+        return torch_version >= version.parse("2.6.0")
+
+    return unittest.skipUnless(is_min_2_6_0(), "test requires torch>=2.6.0")(test_case)
+
+
 def require_torch_lt_2_6_0(test_case):
     """
-    Decorator marking a test that requires torch >= 2.5.1
+    Decorator marking a test that requires torch < 2.6.0
     """
 
     def is_max_2_6_0():
