
Commit ff0fe76

xformers attention with packing (axolotl-ai-cloud#2619)
* xformers attention with packing
* wire up the patch
* fix xformers + packing validation
* fix warning
* reorder the packing check
* fix fp16 / bf16 reset when using fp16 with bf16 auto
* fix seq lens calc to drop hanging sequences
* handle xformers patch for inference too
* fix batch size setter
* fix xformers inference
* add colab callback to fix inference post train
* PR feedback
1 parent 8e4158c commit ff0fe76

File tree

8 files changed, +222 -12 lines changed

docs/config.qmd

Lines changed: 2 additions & 1 deletion

@@ -73,11 +73,12 @@ load_in_8bit: true
 load_in_4bit:
 
 # Use CUDA bf16
-bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
+bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
 # Use CUDA fp16
 fp16: true
 # Use CUDA tf32
 tf32: true # require >=ampere
+# Note: if bf16 is set to 'auto' and fp16 is set to true, we will prefer the explicit fp16 setting
 
 # No AMP (automatic mixed precision)
 bfloat16: true # require >=ampere

src/axolotl/core/trainer_builder.py

Lines changed: 6 additions & 0 deletions

@@ -21,6 +21,7 @@
 import inspect
 import logging
 import math
+import os
 import sys
 from abc import abstractmethod
 from pathlib import Path
@@ -72,6 +73,7 @@
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
     causal_lm_bench_eval_callback_factory,
+    colab_inference_post_train_callback,
     log_prediction_callback_factory,
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
@@ -293,6 +295,10 @@ def get_post_trainer_create_callbacks(self, trainer):
         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             callbacks.append(lisa_callback_factory(trainer))
 
+        if any("COLAB_" in key for key in os.environ):
+            ColabCallback = colab_inference_post_train_callback(trainer)
+            callbacks.append(ColabCallback(self.cfg))
+
         callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
         return callbacks
 

src/axolotl/monkeypatch/attention/__init__.py

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+"""
+attention module for attention monkeypatches
+"""
+
+from transformers.integrations.flash_attention import flash_attention_forward
+
+
+def patch_xformers_attn_over_fa2():
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+    from .xformers import xformers_attention_forward
+
+    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = xformers_attention_forward
+
+
+def unpatch_xformers_attn_over_fa2():
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
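
The patch swaps the kernel registered under the "flash_attention_2" key in transformers' attention dispatch table, so a model loaded with that implementation transparently runs the xformers kernel instead; the unpatch restores the stock flash-attention forward. A minimal sketch of the effect (assuming a transformers build that routes attention through ALL_ATTENTION_FUNCTIONS, as the imports above imply):

    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    from axolotl.monkeypatch.attention import (
        patch_xformers_attn_over_fa2,
        unpatch_xformers_attn_over_fa2,
    )

    patch_xformers_attn_over_fa2()
    # the "flash_attention_2" slot now resolves to the xformers forward
    assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"].__name__ == "xformers_attention_forward"

    unpatch_xformers_attn_over_fa2()  # restore the stock flash-attention kernel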
src/axolotl/monkeypatch/attention/xformers.py

Lines changed: 160 additions & 0 deletions

@@ -0,0 +1,160 @@
+"""
+xformers attention implementation for packing
+"""
+
+from typing import Optional
+
+import torch
+import xformers
+import xformers.ops.fmha
+from transformers.modeling_flash_attention_utils import (
+    _upad_input,
+)
+
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+
+xformers_attention = xformers.ops.fmha.memory_efficient_attention
+
+
+def xformers_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    dropout: float = 0.0,  # pylint: disable=unused-argument
+    scaling: Optional[float] = None,  # pylint: disable=unused-argument
+    sliding_window: Optional[int] = None,  # pylint: disable=unused-argument
+    softcap: Optional[float] = None,  # pylint: disable=unused-argument
+    cu_seq_lens_q: Optional[torch.LongTensor] = None,
+    cu_seq_lens_k: Optional[torch.LongTensor] = None,
+    max_length_q: Optional[int] = None,
+    max_length_k: Optional[int] = None,  # pylint: disable=unused-argument
+    **kwargs,  # pylint: disable=unused-argument
+):
+    # Get dimensions
+    # query: [batch, heads, seq_len, hidden_dim]
+    batch_size = query.size(0)
+    query_length = query.shape[2]
+    key_length = key.shape[2]
+
+    # Default causal mask
+    attn_bias = xformers.ops.LowerTriangularMask()
+
+    # Check if we have sliding window attention
+    has_sliding_window = sliding_window is not None and sliding_window < query_length
+
+    # Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+
+    # Get GQA parameters
+    num_attention_heads = module.config.num_attention_heads
+    num_key_value_heads = module.config.num_key_value_heads
+    head_dim = query.size(-1)
+    is_gqa = num_attention_heads != num_key_value_heads
+    n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
+
+    # If position_ids is provided, check whether the batch holds more than one packed sequence: if the tensor
+    # is monotonically increasing we probably have a single sequence, otherwise it is packed. Also check that
+    # we are in the pre-fill/training stage. A block-diagonal causal bias prevents cross-example attention.
+    if position_ids is not None and (
+        max_length_q is not None
+        or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
+    ):
+        if cu_seq_lens_q is None or cu_seq_lens_k is None:
+            cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
+            cu_seq_lens_q = cu_seq_lens_q.squeeze()
+            seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+            attn_bias = (
+                xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                    q_seqlen=seq_lengths.tolist(),
+                )
+            )
+        else:
+            query = query.reshape(-1, query.size(-2), query.size(-1))
+            key = key.reshape(-1, key.size(-2), key.size(-1))
+            value = value.reshape(-1, value.size(-2), value.size(-1))
+
+        # Handle GQA
+        if is_gqa:
+            key = key.repeat_interleave(n_groups, dim=2)
+            value = value.repeat_interleave(n_groups, dim=2)
+
+    elif attention_mask is not None:
+        query, key, value, _, cu_seq_lens, _ = _upad_input(
+            query, key, value, attention_mask, query_length
+        )
+        cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
+        seq_lengths = []
+        for i in range(len(cu_seq_lens_q) - 1):
+            seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
+        attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+            q_seqlen=seq_lengths,
+            kv_seqlen=seq_lengths,
+        )
+
+        # Handle GQA
+        if is_gqa:
+            key = key.repeat_interleave(n_groups, dim=2)
+            value = value.repeat_interleave(n_groups, dim=2)
+    else:
+        # Handle Group Query Attention (GQA) using view/expand approach from reference
+        key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
+        value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
+        key = key.expand(
+            batch_size, key_length, num_key_value_heads, n_groups, head_dim
+        )
+        value = value.expand(
+            batch_size, key_length, num_key_value_heads, n_groups, head_dim
+        )
+
+        if module.training:
+            key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
+            value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
+
+            if has_sliding_window:
+                query = query.view(
+                    1, batch_size * query_length, num_attention_heads, head_dim
+                )
+                key = key.view(
+                    1, batch_size * key_length, num_attention_heads, head_dim
+                )
+                value = value.view(
+                    1, batch_size * key_length, num_attention_heads, head_dim
+                )
+        else:
+            query = query.view(
+                batch_size, query_length, num_key_value_heads, n_groups, head_dim
+            )
+
+            # If we need a sliding window attention
+            if has_sliding_window:
+                query = query.view(
+                    1,
+                    batch_size * query_length,
+                    num_key_value_heads,
+                    n_groups,
+                    head_dim,
+                )
+                key = key.view(
+                    1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
+                )
+                value = value.view(
+                    1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
+                )
+
+    # Run the xformers attention
+    attn_output = xformers_attention(
+        query,
+        key,
+        value,
+        attn_bias=attn_bias,
+    )
+
+    attn_output = attn_output.view(
+        batch_size, -1, attn_output.size(-2), attn_output.size(-1)
+    )
+    return attn_output, None
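
For packed batches, the forward derives per-sequence lengths from position_ids and builds a block-diagonal causal bias, so tokens only attend within their own packed sequence. A small illustrative sketch with made-up inputs, where two sequences of lengths 3 and 2 are packed into one row (the boundary detection below is a simplified stand-in for get_cu_seqlens_from_pos_ids):

    import torch
    import xformers.ops.fmha as fmha

    position_ids = torch.tensor([[0, 1, 2, 0, 1]])
    # sequence boundaries are where position_ids resets to 0 -> cu_seqlens = [0, 3, 5]
    starts = (position_ids[0] == 0).nonzero().flatten()
    cu_seqlens = torch.cat([starts, torch.tensor([position_ids.size(1)])])
    seq_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()  # [3, 2]
    attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(q_seqlen=seq_lengths)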

src/axolotl/utils/callbacks/__init__.py

Lines changed: 25 additions & 0 deletions

@@ -868,3 +868,28 @@ def on_epoch_end(
     ):
         torch.cuda.empty_cache()
         gc.collect()
+
+
+def colab_inference_post_train_callback(trainer: Trainer):
+    class ColabCallback(TrainerCallback):
+        """Callback to prep model for inference on Google Colab"""
+
+        def __init__(self, cfg):
+            self.gpu_name = torch.cuda.get_device_name(0)
+            self.cfg = cfg
+
+        def on_train_end(
+            self, args, state, control, **kwargs
+        ):  # pylint: disable=unused-argument
+            """
+            Handle T4 GPUs: we need to convert attention to eager for inference
+            """
+            if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
+                trainer.model.config._attn_implementation = (  # pylint: disable=protected-access
+                    "eager"
+                )
+                trainer.model.gradient_checkpointing_disable()
+                trainer.model.config.use_cache = True
+                trainer.model.eval()
+
+    return ColabCallback
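
The factory closes over the live trainer instance, so the returned TrainerCallback subclass can switch that trainer's model to eager attention once training finishes. A usage sketch mirroring the wiring in trainer_builder.py above (trainer and cfg are assumed to already exist):

    ColabCallback = colab_inference_post_train_callback(trainer)
    trainer.add_callback(ColabCallback(cfg))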

src/axolotl/utils/config/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -70,6 +70,9 @@ def resolve_dtype(cfg):
     if cfg.fp16 is None and not cfg.float16:
         cfg.fp16 = True
 
+    if cfg.fp16 and cfg.bf16 == "auto":
+        cfg.bf16 = False
+
     if cfg.device == "mps":
         cfg.load_in_8bit = False
         cfg.tf32 = False
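
A minimal sketch of the precedence rule this adds, using a hypothetical cfg object: an explicit fp16: true wins over bf16: auto, which is resolved to False before training starts:

    from types import SimpleNamespace

    cfg = SimpleNamespace(fp16=True, bf16="auto")
    if cfg.fp16 and cfg.bf16 == "auto":  # same check as resolve_dtype above
        cfg.bf16 = False
    assert cfg.fp16 is True and cfg.bf16 is False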

src/axolotl/utils/models.py

Lines changed: 5 additions & 0 deletions

@@ -556,6 +556,11 @@ def __init__(
         self.auto_model_loader = AutoModelForCausalLM  # pylint: disable=invalid-name
 
     def apply_patches(self) -> None:
+        if self.cfg.xformers_attention and self.cfg.sample_packing:
+            from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
+
+            patch_xformers_attn_over_fa2()
+            self.cfg.flash_attention = True
         if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
             from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
 

src/axolotl/utils/schemas/config.py

Lines changed: 2 additions & 11 deletions

@@ -435,16 +435,6 @@ def check_gptq_w_revision(cls, data):
             )
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_sample_packing_w_xformers(cls, data):
-        if data.get("sample_packing") and data.get("xformers_attention"):
-            raise ValueError(
-                "sample_packing not compatible with xformers_attention. Use flash_attention"
-            )
-
-        return data
-
     @model_validator(mode="before")
     @classmethod
     # pylint: disable=duplicate-code
@@ -471,9 +461,10 @@ def check_sample_packing_wo_flash(cls, data):
             and not data.get("flash_attention")
             and not data.get("sdp_attention")
            and not data.get("flex_attention")
+            and not data.get("xformers_attention")
         ):
             LOG.warning(
-                "sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
+                "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
             )
 
         return data
