Commit eb45843

Introduce FlashAttention3 support

1 parent 9bfd5e2 commit eb45843

File tree: 22 files changed, +794 -65 lines
src/fairseq2/datasets/_batch.py

Lines changed: 38 additions & 5 deletions

@@ -13,6 +13,7 @@
 from typing_extensions import override

 from fairseq2.device import Device, SupportsDeviceTransfer
+from fairseq2.error import InvalidOperationError
 from fairseq2.nn import BatchLayout


@@ -105,7 +106,7 @@ def __init__(
         self._padding = 0

         for idx, seq_len in enumerate(seq_lens):
-            if seq_len < 0:
+            if seq_len < 1:
                 raise ValueError(
                     f"All lengths in `seq_lens` must be greater than or equal to 1, but the length at index {idx} is {seq_len} instead."
                 )
@@ -163,11 +164,27 @@ def as_auto_regressive(self) -> tuple[SequenceBatch, SequenceBatch]:

             seq_lens = self._seq_lens.copy()

-            seq_lens[-1] -= 1
+            if seq_lens[-1] == 1:
+                if len(seq_lens) == 1:
+                    raise InvalidOperationError(
+                        "The length of the sequence at index 0 is already 1 and cannot be trimmed to 0."
+                    )
+
+                del seq_lens[-1]
+            else:
+                seq_lens[-1] -= 1
         else:
             seqs = self._seqs[:, :-1]

-            seq_lens = [seq_len - 1 for seq_len in self._seq_lens]
+            seq_lens = []
+
+            for idx, seq_len in enumerate(self._seq_lens):
+                if seq_len == 1:
+                    raise InvalidOperationError(
+                        f"The length of the sequence at index {idx} is already 1 and cannot be trimmed to 0."
+                    )
+
+                seq_lens.append(seq_len - 1)

         batch = SequenceBatch(
             seqs, seq_lens, packed=self._packed, example=self._example
@@ -487,11 +504,27 @@ def as_auto_regressive(self) -> tuple[Seq2SeqBatch, SequenceBatch]:

             seq_lens = self._target_seq_lens.copy()

-            seq_lens[-1] -= 1
+            if seq_lens[-1] == 1:
+                if len(seq_lens) == 1:
+                    raise InvalidOperationError(
+                        "The length of the target sequence at index 0 is already 1 and cannot be trimmed to 0."
+                    )
+
+                del seq_lens[-1]
+            else:
+                seq_lens[-1] -= 1
         else:
             seqs = self._target_seqs[:, :-1]

-            seq_lens = [seq_len - 1 for seq_len in self._target_seq_lens]
+            seq_lens = []
+
+            for idx, seq_len in enumerate(self._target_seq_lens):
+                if seq_len == 1:
+                    raise InvalidOperationError(
+                        f"The length of the target sequence at index {idx} is already 1 and cannot be trimmed to 0."
+                    )
+
+                seq_lens.append(seq_len - 1)

         batch = Seq2SeqBatch(
             self._source_seqs,
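A note on the as_auto_regressive change: trimming the last step could previously drive a length-1 sequence down to length 0. The new code either drops a final length-1 sequence (packed batches, where only the last sequence is trimmed) or raises InvalidOperationError (padded batches, where every sequence is trimmed). A minimal standalone sketch of the same rule; trim_last_step is a hypothetical helper, not fairseq2's actual API:

# Sketch of the trimming rule above; fairseq2 implements it inside
# SequenceBatch/Seq2SeqBatch, not as a free function.
def trim_last_step(seq_lens: list[int], packed: bool) -> list[int]:
    if packed:
        seq_lens = seq_lens.copy()
        # Packed batch: only the final sequence loses a step. A length-1
        # sequence is dropped entirely rather than shrunk to 0.
        if seq_lens[-1] == 1:
            if len(seq_lens) == 1:
                raise ValueError("cannot trim the only sequence to length 0")
            del seq_lens[-1]
        else:
            seq_lens[-1] -= 1
        return seq_lens
    # Padded batch: every sequence loses its last step, so none may
    # already have length 1.
    for idx, seq_len in enumerate(seq_lens):
        if seq_len == 1:
            raise ValueError(f"sequence {idx} cannot be trimmed to length 0")
    return [seq_len - 1 for seq_len in seq_lens]

assert trim_last_step([4, 3, 1], packed=True) == [4, 3]
assert trim_last_step([4, 3], packed=False) == [3, 2]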

src/fairseq2/datasets/instruction.py

Lines changed: 3 additions & 1 deletion

@@ -297,7 +297,9 @@ def skip(example: dict[str, Any]) -> bool:
             "target_mask", pad_value=False
         )

-        collater = Collater(pad_value=0, overrides=[target_mask_collate_opts])
+        collater = Collater(
+            pad_value=tokenizer.vocab_info.pad_idx, overrides=[target_mask_collate_opts]
+        )

         builder.map(collater, num_parallel_calls=npc)

src/fairseq2/metrics/recorders/_tensorboard.py

Lines changed: 9 additions & 9 deletions

@@ -13,6 +13,13 @@

 from typing_extensions import override

+try:
+    from torch.utils.tensorboard import SummaryWriter  # type: ignore[attr-defined]
+except ImportError:
+    _has_tensorboard = False
+else:
+    _has_tensorboard = True
+
 from fairseq2.logging import log
 from fairseq2.metrics import MetricDescriptor
 from fairseq2.registry import Provider
@@ -28,13 +35,6 @@
     NoopMetricRecorder,
 )

-try:
-    from torch.utils.tensorboard import SummaryWriter  # type: ignore[attr-defined]
-except ImportError:
-    has_tensorboard = False
-else:
-    has_tensorboard = True
-

 @final
 class TensorBoardRecorder(MetricRecorder):
@@ -51,7 +51,7 @@ def __init__(
         :param output_dir:
             The base directory under which to store the TensorBoard files.
         """
-        if not has_tensorboard:
+        if not _has_tensorboard:
            log.warning("tensorboard not found. Please install it with `pip install tensorboard`.")  # fmt: skip

         self._output_dir = output_dir
@@ -94,7 +94,7 @@ def record_metrics(
         ) from ex

     def _get_writer(self, run: str) -> SummaryWriter | None:
-        if not has_tensorboard:
+        if not _has_tensorboard:
             return None

         writer = self._writers.get(run)
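The change hoists the guarded import above the fairseq2 imports and renames the module-level flag to the private _has_tensorboard; the W&B recorder below gets the identical treatment. The optional-dependency pattern in isolation, with a hypothetical make_writer helper standing in for the recorder's _get_writer:

# Probe for the optional package once at import time; call sites then
# degrade to a no-op instead of raising ImportError.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    _has_tensorboard = False
else:
    _has_tensorboard = True


def make_writer(log_dir: str) -> "SummaryWriter | None":
    if not _has_tensorboard:
        return None  # caller is expected to warn and skip recording

    return SummaryWriter(log_dir=log_dir)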

src/fairseq2/metrics/recorders/_wandb.py

Lines changed: 8 additions & 8 deletions

@@ -13,6 +13,13 @@

 from typing_extensions import override

+try:
+    import wandb  # type: ignore[import-not-found]
+except ImportError:
+    _has_wandb = False
+else:
+    _has_wandb = True
+
 from fairseq2.logging import log
 from fairseq2.metrics import MetricDescriptor
 from fairseq2.registry import Provider
@@ -28,13 +35,6 @@
     NoopMetricRecorder,
 )

-try:
-    import wandb  # type: ignore[import-not-found]
-except ImportError:
-    has_wandb = False
-else:
-    has_wandb = True
-

 @final
 class WandbRecorder(MetricRecorder):
@@ -57,7 +57,7 @@ def __init__(
         In order to use W&B, run `wandb login` from the command line and enter
         the API key when prompted.
         """
-        if not has_wandb:
+        if not _has_wandb:
            log.warning("wandb not found. Please install it with `pip install wandb`.")  # fmt: skip

         self._run = None

src/fairseq2/models/llama/_factory.py

Lines changed: 3 additions & 2 deletions

@@ -93,8 +93,9 @@ def init_embed(embed: StandardEmbedding) -> None:
             self._init_truncated_normal(embed.weight, bias=None, std=std)

         return StandardEmbedding(
-            num_embeddings=config.vocab_size,
-            embedding_dim=config.model_dim,
+            config.vocab_size,
+            config.model_dim,
+            pad_idx=config.pad_idx,
             init_fn=init_embed,
         )
src/fairseq2/models/mistral/_factory.py

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ def create_embedding(self) -> Embedding:
         config = self._config

         return StandardEmbedding(
-            num_embeddings=config.vocab_size, embedding_dim=config.model_dim
+            config.vocab_size, config.model_dim, pad_idx=config.pad_idx
         )

     def create_decoder(self) -> TransformerLMDecoder:

src/fairseq2/models/s2t_transformer/_factory.py

Lines changed: 2 additions & 2 deletions

@@ -232,8 +232,8 @@ def create_target_embedding(self) -> Embedding:
         config = self._config

         return StandardEmbedding(
-            num_embeddings=config.target_vocab_size,
-            embedding_dim=config.model_dim,
+            config.target_vocab_size,
+            config.model_dim,
             pad_idx=config.pad_idx,
             init_fn=init_scaled_embedding,
         )

src/fairseq2/models/transformer/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -142,6 +142,8 @@
 from fairseq2.models.transformer._sdpa._default import (
     set_default_sdpa_factory as set_default_sdpa_factory,
 )
+from fairseq2.models.transformer._sdpa._flash2 import Flash2SDPA as Flash2SDPA
+from fairseq2.models.transformer._sdpa._flash3 import Flash3SDPA as Flash3SDPA
 from fairseq2.models.transformer._sdpa._naive import NaiveSDPA as NaiveSDPA
 from fairseq2.models.transformer._sdpa._naive import (
     naive_scaled_dot_product_attention as naive_scaled_dot_product_attention,
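These exports put Flash2SDPA and Flash3SDPA in the public transformer API next to NaiveSDPA and set_default_sdpa_factory; the Flash3SDPA implementation itself lives in the new _sdpa/_flash3 module, whose contents are not shown in this excerpt. As an illustration only, not fairseq2's actual Flash3 code, the same backend-pinning idea expressed with plain PyTorch, which routes scaled_dot_product_attention through its flash-attention kernel:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Plain-PyTorch stand-in for a flash-backed SDPA module: pin the
# flash-attention kernel for this call. Requires a supported CUDA GPU
# and fp16/bf16 inputs.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)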

src/fairseq2/models/transformer/_factory.py

Lines changed: 2 additions & 2 deletions

@@ -96,8 +96,8 @@ def create_embedding(self) -> Embedding:
         config = self._config

         return StandardEmbedding(
-            num_embeddings=config.vocab_size,
-            embedding_dim=config.model_dim,
+            config.vocab_size,
+            config.model_dim,
             pad_idx=config.pad_idx,
             init_fn=init_scaled_embedding,
         )

src/fairseq2/models/transformer/_sdpa/__init__.py

Whitespace-only changes.
