67 changes: 67 additions & 0 deletions bionemo-recipes/models/esm2/src/esm/collator.py
@@ -276,6 +276,73 @@ def set_epoch(self, epoch: int):


@dataclass
class SequencePackingIterableDataset(torch.utils.data.IterableDataset):
Collaborator:

Could we just wrap the existing TokenPackingDataset to do this?

Set max_tokens_per_batch to the desired sequence length and split_samples=True; then, before returning each sample, concatenate along the sequence dimension.
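A rough sketch of that suggestion, assuming TokenPackingDataset is the existing class in this module, that it accepts max_tokens_per_batch and split_samples keyword arguments as described above, and that with split_samples=True each yielded list totals exactly max_tokens_per_batch tokens (none of this is verified here):

import torch
from torch.utils.data import IterableDataset


class ConcatenatedTokenPackingDataset(IterableDataset):
    """Hypothetical wrapper: reuse TokenPackingDataset, then flatten each packed list."""

    def __init__(self, dataset, max_seq_length):
        # `dataset` is assumed to already be tokenized into dicts of token-id lists.
        self.inner = TokenPackingDataset(
            dataset, max_tokens_per_batch=max_seq_length, split_samples=True
        )

    def __iter__(self):
        for packed in iter(self.inner):
            # Concatenate the variable-length samples along the sequence dimension so each
            # yielded sample is max_seq_length tokens long.
            yield {
                "input_ids": [tok for s in packed for tok in s["input_ids"]],
                "labels": [tok for s in packed for tok in s.get("labels", s["input_ids"])],
            }

    def set_epoch(self, epoch):
        if hasattr(self.inner, "set_epoch"):
            self.inner.set_epoch(epoch)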

"""Dataset that packs sequences by concatenating across boundaries for BSHD format.

This dataset concatenates incoming sequences into one token stream and emits fixed-length samples, splitting arbitrarily across sequence boundaries.

Example (with max_seq_length=4):
Input sequences: AAA, BBBBB, CCC, DDDDDDD
Output samples: AAAB, BBBB, CCCD, DDDD, DD (remainder; dropped when drop_last=True, padded otherwise)

Unlike TokenPackingDataset (which yields variable-count lists for THD format with cu_seqlens),
this yields fixed-length samples for BSHD format with pure causal masking.

Args:
dataset: The input IterableDataset (can be windowed or not).
max_seq_length: Fixed length for each output sample.
pad_token_id: Token ID for padding (only used if drop_last=False).
drop_last: Whether to drop the last incomplete sample. Default: True.
"""

dataset: datasets.IterableDataset
max_seq_length: int
pad_token_id: int
drop_last: bool = True

def __iter__(self):
"""Yield fixed-length samples by streaming tokens across sequence boundaries."""
buffer_input_ids = []
buffer_labels = []
has_labels = None

for sample in iter(self.dataset):
# Determine if dataset has labels on first sample
if has_labels is None:
has_labels = "labels" in sample

# Add tokens to buffer
buffer_input_ids.extend(sample["input_ids"])
if has_labels:
buffer_labels.extend(sample.get("labels", sample["input_ids"]))

# Yield full chunks of max_seq_length
while len(buffer_input_ids) >= self.max_seq_length:
output = {"input_ids": buffer_input_ids[: self.max_seq_length]}
if has_labels:
output["labels"] = buffer_labels[: self.max_seq_length]

yield output

# Keep remainder in buffer
buffer_input_ids = buffer_input_ids[self.max_seq_length :]
if has_labels:
buffer_labels = buffer_labels[self.max_seq_length :]

# Handle remaining tokens (only if not dropping last)
if buffer_input_ids and not self.drop_last:
padding_length = self.max_seq_length - len(buffer_input_ids)
output = {"input_ids": buffer_input_ids + [self.pad_token_id] * padding_length}
if has_labels:
output["labels"] = buffer_labels + [-100] * padding_length
yield output

def set_epoch(self, epoch: int):
"""Set the epoch for the underlying dataset."""
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)


class DataCollatorForContextParallel:
"""A collator that is aware of context parallelism.

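For orientation, here is a minimal usage sketch of the new SequencePackingIterableDataset; the tiny in-memory dataset, pad id, and collate function are illustrative placeholders, not taken from the recipe:

import datasets
import torch

# Tiny in-memory stand-in for a tokenized streaming dataset (illustrative only).
tokenized_dataset = datasets.Dataset.from_dict(
    {"input_ids": [[5, 6, 7], [8, 9, 10, 11, 12], [13, 14, 15]]}
).to_iterable_dataset()

packed = SequencePackingIterableDataset(
    dataset=tokenized_dataset,
    max_seq_length=4,
    pad_token_id=0,  # placeholder pad id; only consulted when drop_last=False
    drop_last=True,
)


def collate(batch):
    # Every packed sample has the same length, so plain stacking yields BSHD-ready tensors.
    return {key: torch.tensor([sample[key] for sample in batch]) for key in batch[0]}


loader = torch.utils.data.DataLoader(packed, batch_size=2, collate_fn=collate)
for batch in loader:
    print(batch["input_ids"].shape)  # torch.Size([2, 4]): 11 input tokens -> 2 full chunks of 4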
67 changes: 67 additions & 0 deletions bionemo-recipes/models/llama3/collator.py
@@ -276,6 +276,73 @@ def set_epoch(self, epoch: int):


@dataclass
class SequencePackingIterableDataset(torch.utils.data.IterableDataset):
"""Dataset that packs sequences by concatenating across boundaries for BSHD format.

This dataset concatenates incoming sequences into one token stream and emits fixed-length samples, splitting arbitrarily across sequence boundaries.

Example (with max_seq_length=4):
Input sequences: AAA, BBBBB, CCC, DDDDDDD
Output samples: AAAB, BBBB, CCCD, DDDD, DD (remainder; dropped when drop_last=True, padded otherwise)

Unlike TokenPackingDataset (which yields variable-count lists for THD format with cu_seqlens),
this yields fixed-length samples for BSHD format with pure causal masking.

Args:
dataset: The input IterableDataset (can be windowed or not).
max_seq_length: Fixed length for each output sample.
pad_token_id: Token ID for padding (only used if drop_last=False).
drop_last: Whether to drop the last incomplete sample. Default: True.
"""

dataset: datasets.IterableDataset
max_seq_length: int
pad_token_id: int
drop_last: bool = True

def __iter__(self):
"""Yield fixed-length samples by streaming tokens across sequence boundaries."""
buffer_input_ids = []
buffer_labels = []
has_labels = None

for sample in iter(self.dataset):
# Determine if dataset has labels on first sample
if has_labels is None:
has_labels = "labels" in sample

# Add tokens to buffer
buffer_input_ids.extend(sample["input_ids"])
if has_labels:
buffer_labels.extend(sample.get("labels", sample["input_ids"]))

# Yield full chunks of max_seq_length
while len(buffer_input_ids) >= self.max_seq_length:
output = {"input_ids": buffer_input_ids[: self.max_seq_length]}
if has_labels:
output["labels"] = buffer_labels[: self.max_seq_length]

yield output

# Keep remainder in buffer
buffer_input_ids = buffer_input_ids[self.max_seq_length :]
if has_labels:
buffer_labels = buffer_labels[self.max_seq_length :]

# Handle remaining tokens (only if not dropping last)
if buffer_input_ids and not self.drop_last:
padding_length = self.max_seq_length - len(buffer_input_ids)
output = {"input_ids": buffer_input_ids + [self.pad_token_id] * padding_length}
if has_labels:
output["labels"] = buffer_labels + [-100] * padding_length
yield output

def set_epoch(self, epoch: int):
"""Set the epoch for the underlying dataset."""
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)


class DataCollatorForContextParallel:
"""A collator that is aware of context parallelism.

38 changes: 25 additions & 13 deletions bionemo-recipes/models/llama3/modeling_llama_te.py
@@ -14,6 +14,7 @@
# limitations under the License.

from collections import OrderedDict
from contextlib import nullcontext
from typing import Unpack

import torch
@@ -44,6 +45,8 @@ class NVLlamaConfig(LlamaConfig):

attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
fp8_first_last_bf16: bool = False
"""When True, keeps first and last transformer layers in bf16 for FP8 numerical stability."""


class NVLlamaPreTrainedModel(PreTrainedModel):
@@ -221,24 +224,33 @@ def forward(
with torch.autocast(device_type="cuda", enabled=False):
te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
num_layers = self.config.num_hidden_layers
for layer_idx, decoder_layer in enumerate(self.layers[:num_layers]):
if output_hidden_states:
all_hidden_states = (*all_hidden_states, hidden_states)

hidden_states = decoder_layer(
hidden_states,
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
rotary_pos_emb=te_rope_emb,
inference_params=past_key_values,
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
max_seqlen_q=kwargs.get("max_length_q", None),
max_seqlen_kv=kwargs.get("max_length_k", None),
pad_between_seqs=kwargs.get("pad_between_seqs", None),
# Optionally keep first and last layers in bf16 for FP8 numerical stability
use_bf16_for_layer = getattr(self.config, "fp8_first_last_bf16", False) and (
layer_idx == 0 or layer_idx == num_layers - 1
)

# If fp8_first_last_bf16 is enabled, disable FP8 for the first and last layers.
# This nested fp8_autocast overrides the outer one from the training script.
with transformer_engine.pytorch.fp8_autocast(enabled=False) if use_bf16_for_layer else nullcontext():
hidden_states = decoder_layer(
hidden_states,
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
rotary_pos_emb=te_rope_emb,
inference_params=past_key_values,
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
max_seqlen_q=kwargs.get("max_length_q", None),
max_seqlen_kv=kwargs.get("max_length_k", None),
pad_between_seqs=kwargs.get("pad_between_seqs", None),
)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer. Note that these will be in THD format; we could possibly pad
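For orientation, a minimal sketch of the FP8 nesting behavior the change above relies on: an inner fp8_autocast(enabled=False) overrides an enabled outer context for the modules executed inside it. The te.Linear stack below is a stand-in, not the recipe's model or training script:

from contextlib import nullcontext

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Four stand-in layers; in the model above, only layer 0 and the last layer are kept in bf16.
layers = torch.nn.ModuleList(
    [te.Linear(256, 256, params_dtype=torch.bfloat16) for _ in range(4)]
).cuda()
x = torch.randn(8, 256, device="cuda", dtype=torch.bfloat16)

fp8_recipe = DelayedScaling()  # default delayed-scaling recipe; tune for real runs
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):  # outer context, as in a training script
    for i, layer in enumerate(layers):
        keep_bf16 = i in (0, len(layers) - 1)  # mirrors config.fp8_first_last_bf16
        with te.fp8_autocast(enabled=False) if keep_bf16 else nullcontext():
            x = layer(x)  # first/last layers run in bf16; middle layers run with FP8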
67 changes: 67 additions & 0 deletions bionemo-recipes/recipes/esm2_native_te/collator.py
@@ -276,6 +276,73 @@ def set_epoch(self, epoch: int):


@dataclass
class SequencePackingIterableDataset(torch.utils.data.IterableDataset):
"""Dataset that packs sequences by concatenating across boundaries for BSHD format.

This dataset concatenates incoming sequences into one token stream and emits fixed-length samples, splitting arbitrarily across sequence boundaries.

Example (with max_seq_length=4):
Input sequences: AAA, BBBBB, CCC, DDDDDDD
Output samples: AAAB, BBBB, CCCD, DDDD, DD (remainder; dropped when drop_last=True, padded otherwise)

Unlike TokenPackingDataset (which yields variable-count lists for THD format with cu_seqlens),
this yields fixed-length samples for BSHD format with pure causal masking.

Args:
dataset: The input IterableDataset (can be windowed or not).
max_seq_length: Fixed length for each output sample.
pad_token_id: Token ID for padding (only used if drop_last=False).
drop_last: Whether to drop the last incomplete sample. Default: True.
"""

dataset: datasets.IterableDataset
max_seq_length: int
pad_token_id: int
drop_last: bool = True

def __iter__(self):
"""Yield fixed-length samples by streaming tokens across sequence boundaries."""
buffer_input_ids = []
buffer_labels = []
has_labels = None

for sample in iter(self.dataset):
# Determine if dataset has labels on first sample
if has_labels is None:
has_labels = "labels" in sample

# Add tokens to buffer
buffer_input_ids.extend(sample["input_ids"])
if has_labels:
buffer_labels.extend(sample.get("labels", sample["input_ids"]))

# Yield full chunks of max_seq_length
while len(buffer_input_ids) >= self.max_seq_length:
output = {"input_ids": buffer_input_ids[: self.max_seq_length]}
if has_labels:
output["labels"] = buffer_labels[: self.max_seq_length]

yield output

# Keep remainder in buffer
buffer_input_ids = buffer_input_ids[self.max_seq_length :]
if has_labels:
buffer_labels = buffer_labels[self.max_seq_length :]

# Handle remaining tokens (only if not dropping last)
if buffer_input_ids and not self.drop_last:
padding_length = self.max_seq_length - len(buffer_input_ids)
output = {"input_ids": buffer_input_ids + [self.pad_token_id] * padding_length}
if has_labels:
output["labels"] = buffer_labels + [-100] * padding_length
yield output

def set_epoch(self, epoch: int):
"""Set the epoch for the underlying dataset."""
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)


class DataCollatorForContextParallel:
"""A collator that is aware of context parallelism.

67 changes: 67 additions & 0 deletions bionemo-recipes/recipes/llama3_native_te/collator.py
@@ -276,6 +276,73 @@ def set_epoch(self, epoch: int):


@dataclass
class SequencePackingIterableDataset(torch.utils.data.IterableDataset):
"""Dataset that packs sequences by concatenating across boundaries for BSHD format.

This dataset concatenates incoming sequences into one token stream and emits fixed-length samples, splitting arbitrarily across sequence boundaries.

Example (with max_seq_length=4):
Input sequences: AAA, BBBBB, CCC, DDDDDDD
Output samples: AAAB, BBBB, CCCD, DDDD, DD (remainder; dropped when drop_last=True, padded otherwise)

Unlike TokenPackingDataset (which yields variable-count lists for THD format with cu_seqlens),
this yields fixed-length samples for BSHD format with pure causal masking.

Args:
dataset: The input IterableDataset (can be windowed or not).
max_seq_length: Fixed length for each output sample.
pad_token_id: Token ID for padding (only used if drop_last=False).
drop_last: Whether to drop the last incomplete sample. Default: True.
"""

dataset: datasets.IterableDataset
max_seq_length: int
pad_token_id: int
drop_last: bool = True

def __iter__(self):
"""Yield fixed-length samples by streaming tokens across sequence boundaries."""
buffer_input_ids = []
buffer_labels = []
has_labels = None

for sample in iter(self.dataset):
# Determine if dataset has labels on first sample
if has_labels is None:
has_labels = "labels" in sample

# Add tokens to buffer
buffer_input_ids.extend(sample["input_ids"])
if has_labels:
buffer_labels.extend(sample.get("labels", sample["input_ids"]))

# Yield full chunks of max_seq_length
while len(buffer_input_ids) >= self.max_seq_length:
output = {"input_ids": buffer_input_ids[: self.max_seq_length]}
if has_labels:
output["labels"] = buffer_labels[: self.max_seq_length]

yield output

# Keep remainder in buffer
buffer_input_ids = buffer_input_ids[self.max_seq_length :]
if has_labels:
buffer_labels = buffer_labels[self.max_seq_length :]

# Handle remaining tokens (only if not dropping last)
if buffer_input_ids and not self.drop_last:
padding_length = self.max_seq_length - len(buffer_input_ids)
output = {"input_ids": buffer_input_ids + [self.pad_token_id] * padding_length}
if has_labels:
output["labels"] = buffer_labels + [-100] * padding_length
yield output

def set_epoch(self, epoch: int):
"""Set the epoch for the underlying dataset."""
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)


class DataCollatorForContextParallel:
"""A collator that is aware of context parallelism.
