Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions bionemo-recipes/recipes/llama3_native_te/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from typing import Any

import torch
from genomic_masking_functions import make_upper_case
from genomic_masking_functions import Evo2MaskingConstants, make_upper_case, mask_phylogenetic_tags


logger = logging.getLogger(__name__)
Expand All @@ -45,6 +45,7 @@ class GenomicDataCollator:
base_collator: The underlying collator (e.g., DataCollatorForLanguageModeling)
uppercase_labels: Whether to uppercase labels. Default: False.
mask_degenerate_bases: Whether to mask non-ACGT bases. Default: True.
mask_phylo_tags: Whether to mask phylogenetic tags. Default: False (Milestone 2).
dna_tokens: Tuple of valid DNA token IDs (A, C, G, T upper+lowercase)
control_tags: Tuple of control character token IDs (@, #)

Expand All @@ -55,27 +56,47 @@ class GenomicDataCollator:
... base_collator=base,
... uppercase_labels=False,
... mask_degenerate_bases=True,
... mask_phylo_tags=False,
... )
"""

base_collator: Any
uppercase_labels: bool = False
mask_degenerate_bases: bool = True
mask_phylo_tags: bool = False
dna_tokens: tuple[int, ...] = (65, 67, 71, 84, 97, 99, 103, 116) # A, C, G, T (upper+lower)
control_tags: tuple[int, ...] = (64, 35) # '@', '#'

def __call__(self, features: list) -> dict[str, Any]:
"""Apply base collator, then add genomic masking."""
"""Apply base collator, then add genomic masking.

Order of operations (IMPORTANT):
1. Mask phylogenetic tags FIRST (needs pipes and lowercase to detect!)
2. Mask degenerate bases (simple character check)
3. Uppercase labels (after detection, since phylo relies on case)
"""
# Base collator handles batching and CLM label creation
batch = self.base_collator(features)

labels = batch["labels"]

# Step 1: Uppercase labels (inputs stay mixed case)
if self.uppercase_labels:
labels, _ = make_upper_case(labels)

# Step 2: Mask degenerate bases and control characters
# Step 1: Mask phylogenetic tags FIRST (BEFORE degenerate!)
# Phylo detection needs:
# - Pipes (|) to detect boundaries
# - Lowercase letters to identify tags
# - Must run before degenerate masking which would mask pipes!
if self.mask_phylo_tags:
phylo_mask = mask_phylogenetic_tags(
tokenized_sequence=labels,
terminal_tag_char=Evo2MaskingConstants.TAG_BOUNDS,
other_tag_chars=Evo2MaskingConstants.TAG_CHARS,
eod_token_id=Evo2MaskingConstants.DEFAULT_EOD,
max_tag_len=Evo2MaskingConstants.MAX_TAG_LEN,
)
# Where mask is 0, set label to -100 (but preserve existing -100)
labels[(phylo_mask == 0) & (labels != -100)] = -100

# Step 2: Mask degenerate bases and control characters (AFTER phylo!)
if self.mask_degenerate_bases:
dna_tokens_tensor = torch.tensor(self.dna_tokens, device=labels.device)
control_tensor = torch.tensor(self.control_tags, device=labels.device)
Expand All @@ -87,5 +108,9 @@ def __call__(self, features: list) -> dict[str, Any]:
# Mask both, but preserve existing -100 values
labels[(not_dna | is_control) & (labels != -100)] = -100

# Step 3: Uppercase labels (AFTER phylo detection!)
if self.uppercase_labels:
labels, _ = make_upper_case(labels)

batch["labels"] = labels
return batch
9 changes: 6 additions & 3 deletions bionemo-recipes/recipes/llama3_native_te/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def create_bshd_dataloader(
sequence_column: str = "sequence",
uppercase_labels: bool = False,
mask_degenerate_bases: bool = True,
mask_phylo_tags: bool = False,
):
"""Create a BSHD dataloader for genomic sequences using CLM (causal language modeling).

Expand All @@ -147,7 +148,8 @@ def create_bshd_dataloader(
use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state.
sequence_column: Name of the column containing genomic sequences (default: "sequence").
uppercase_labels: Whether to uppercase labels (genomic masking). Default: False.
mask_degenerate_bases: Whether to mask non-ACGT bases (genomic masking). Default: False.
mask_degenerate_bases: Whether to mask non-ACGT bases (genomic masking). Default: True.
mask_phylo_tags: Whether to mask phylogenetic tags (genomic masking). Default: False (Milestone 2).

Returns:
A tuple of (dataloader, dataset_or_sampler).
Expand Down Expand Up @@ -180,16 +182,17 @@ def create_bshd_dataloader(
)

# Wrap with genomic collator if masking options are enabled
if uppercase_labels or mask_degenerate_bases:
if uppercase_labels or mask_degenerate_bases or mask_phylo_tags:
from data_collator import GenomicDataCollator

data_collator = GenomicDataCollator(
base_collator=base_collator,
uppercase_labels=uppercase_labels,
mask_degenerate_bases=mask_degenerate_bases,
mask_phylo_tags=mask_phylo_tags,
)
logger.info(
f"Using GenomicDataCollator (uppercase={uppercase_labels}, mask_degenerate={mask_degenerate_bases})"
f"Using GenomicDataCollator (uppercase={uppercase_labels}, mask_degenerate={mask_degenerate_bases}, mask_phylo={mask_phylo_tags})"
)
else:
# Use base collator directly for backward compatibility
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,201 @@ def make_upper_case(tokens, lowercase_start=97, lowercase_end=122, case_diff=32)
return uppercase_tensor, lowercase_mask


def mask_phylogenetic_tags( # noqa: C901
tokenized_sequence: torch.Tensor,
terminal_tag_char: int = 124, # '|' (pipe)
other_tag_chars: set[int] | None = None, # '_', ';', space
eod_token_id: int = 0,
max_tag_len: int = 2048,
) -> torch.Tensor:
"""Create a binary mask for sequences containing phylogenetic tags.

Adapted from: nemo.collections.llm.gpt.data.megatron.hyena.evo2_dataset.Evo2Dataset.mask_phylogenetic_tags

Phylogenetic tags have format: |d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria|

Detection rules:
- Tags are enclosed in pipes (|)
- Contain taxonomy separators (_, ;, space)
- Start with lowercase letter after pipe (d__, p__, c__, etc.)

Args:
tokenized_sequence: Token IDs. Shape: (seq_len,) or (batch_size, seq_len)
terminal_tag_char: ASCII for '|' (default: 124)
other_tag_chars: ASCII values for tag separators (default: {95, 59, 32} = '_', ';', space)
eod_token_id: End-of-document token (default: 0)
max_tag_len: Maximum tag length (default: 2048)

Returns:
Binary mask: 1 = keep (DNA), 0 = mask (tag). Same shape as input.
"""
if other_tag_chars is None:
other_tag_chars = {95, 59, 32} # '_', ';', space

device = tokenized_sequence.device

# Handle empty sequences
if tokenized_sequence.numel() == 0:
return torch.ones(0, device=device, dtype=torch.int)

# Handle single token
if tokenized_sequence.numel() == 1:
mask = torch.ones(1, device=device, dtype=torch.int)
token = tokenized_sequence.item()
if token == terminal_tag_char or token in other_tag_chars:
mask[0] = 0
return mask

# Ensure 2D (batch, seq_len)
batched = tokenized_sequence.ndim == 2
if not batched:
tokenized_sequence = tokenized_sequence.unsqueeze(0)
batch_size, seq_len = tokenized_sequence.shape

# Valid DNA + degenerate bases + control chars
valid_dna_and_degenerate = {
45,
65,
66,
67,
68,
71,
72,
75,
77,
78,
82,
83,
84,
85,
86,
87,
89, # Uppercase
97,
98,
99,
100,
103,
104,
107,
109,
110,
114,
115,
116,
117,
118,
119,
121, # Lowercase
}
control_tags_set = {64, 35} # '@', '#'
valid_dna_or_control_tensor = torch.tensor(
list(valid_dna_and_degenerate | control_tags_set), device=device, dtype=tokenized_sequence.dtype
)

# Initialize mask to all ones (keep everything)
out_mask = torch.ones_like(tokenized_sequence, dtype=torch.int)

def region_all_valid_or_control(region: torch.Tensor) -> bool:
"""Check if all tokens in region are valid DNA or control chars."""
if region.numel() == 0:
return True
return bool(torch.all(torch.isin(region, valid_dna_or_control_tensor)).cpu().item())

def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: # noqa: C901
"""Process one EOD-free segment."""
seg_len = seg_seq.size(0)
seg_mask = torch.ones(seg_len, device=device, dtype=torch.int)

# Find pipe positions
pipe_pos = (seg_seq == terminal_tag_char).nonzero(as_tuple=True)[0].cpu().tolist()

if len(pipe_pos) == 0:
# No pipes: mask if contains tag chars or invalid DNA and short enough
if seg_len < max_tag_len and not region_all_valid_or_control(seg_seq):
seg_mask.zero_()
return seg_mask

# Mask all pipe positions
seg_mask[pipe_pos] = 0

# Determine if tag starts before first pipe (state machine)
first_pipe = pipe_pos[0]
if first_pipe < seg_len - 1:
# Check token after first pipe
if seg_len > first_pipe + 2:
first_tok = seg_seq[first_pipe + 1].item()
next_tok = seg_seq[first_pipe + 2].item()
# If pattern is [char]_ or starts with ;, tag is AFTER pipe
is_tag = not (next_tok == 95 or first_tok == 59)
elif seg_len > first_pipe + 1:
next_tok = seg_seq[first_pipe + 1].item()
# Check for d, D, r, R (domain/realm) or ; (missing field)
is_tag = next_tok not in {68, 100, 82, 114, 59}
elif first_pipe >= max_tag_len or region_all_valid_or_control(seg_seq[:first_pipe]):
is_tag = False
else:
is_tag = True
else:
# Sequence ends with pipe
if first_pipe >= max_tag_len or region_all_valid_or_control(seg_seq[:first_pipe]):
return seg_mask
else:
seg_mask[:first_pipe] = 0
return seg_mask

# Process regions between pipes (state machine)
start = 0
for end in pipe_pos:
seg_region_len = end - start
if is_tag and seg_region_len < max_tag_len:
seg_mask[start:end] = 0
elif is_tag and seg_region_len >= max_tag_len:
# Too long to be a tag, must be DNA
is_tag = False
# Flip state for next region
is_tag = not is_tag
start = end + 1

# Process region after last pipe
seg_region_len = seg_len - start
if is_tag and seg_region_len < max_tag_len:
seg_mask[start:] = 0

return seg_mask

# Process each batch row, splitting on EOD tokens
for b in range(batch_size):
row = tokenized_sequence[b]
eod_positions = (row == eod_token_id).nonzero(as_tuple=True)[0].cpu().tolist()

start_idx = 0
for pos in eod_positions:
if pos > start_idx:
seg = row[start_idx:pos]
seg_mask = process_segment(seg)
out_mask[b, start_idx:pos] = seg_mask
# Leave EOD unmasked
start_idx = pos + 1

# Process remaining after last EOD
if start_idx < seq_len:
seg = row[start_idx:]
seg_mask = process_segment(seg)
out_mask[b, start_idx:] = seg_mask

# Safety: mask any non-valid DNA tokens that slipped through
out_mask[~torch.isin(tokenized_sequence, valid_dna_or_control_tensor)] = 0

# Force EOD tokens to be unmasked
out_mask[tokenized_sequence == eod_token_id] = 1

if not batched:
out_mask = out_mask.squeeze(0)

return out_mask


class Evo2MaskingConstants:
"""Constants used in Evo2 genomic sequence masking."""

Expand All @@ -54,3 +249,9 @@ class Evo2MaskingConstants:

# Control characters used in data formatting
CONTROL_TAGS: ClassVar[list[int]] = [64, 35] # '@', '#'

# Phylogenetic tag constants
TAG_BOUNDS: ClassVar[int] = 124 # '|' pipe character
TAG_CHARS: ClassVar[set[int]] = {95, 59, 32} # '_', ';', space
MAX_TAG_LEN: ClassVar[int] = 2048
DEFAULT_EOD: ClassVar[int] = 0
60 changes: 60 additions & 0 deletions bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,66 @@ def test_streaming_dataset_handles_missing_record_column(tokenizer_path, tmp_pat
assert "input_ids" in sample, "input_ids should be present"


def test_dataloader_with_phylo_masking(tokenizer_path, tmp_path):
    """End-to-end check of create_bshd_dataloader with mask_phylo_tags=True.

    Verifies that:
    - the genomic collator path is taken when phylo masking is requested,
    - phylo tag characters (pipes, taxonomy letters, separators) never
      survive in the labels,
    - batches are still produced in BSHD [batch, seq] layout.
    """
    # Write a tiny parquet file whose sequences embed phylo tags.
    data_file = tmp_path / "data_with_phylo_tags.parquet"
    tagged_sequences = [
        "ACGT|d__Bacteria|GGTA",  # Has phylo tag
        "TTCC|p__Firmicutes|AACG",  # Has phylo tag
    ]
    pq.write_table(pa.table({"sequence": tagged_sequences}), data_file)

    # Build the dataloader with phylo masking on and the other genomic
    # masking options off, so any label changes are attributable to it.
    dataloader, _ = create_bshd_dataloader(
        distributed_config=DistributedConfig(rank=0, world_size=1),
        tokenizer_path=tokenizer_path,
        load_dataset_kwargs={
            "path": "parquet",
            "data_files": str(data_file),
            "split": "train",
        },
        micro_batch_size=2,
        num_workers=0,
        max_seq_length=30,
        stride=10,
        use_lazy_tokenization=False,
        mask_phylo_tags=True,  # Enable phylo masking
        mask_degenerate_bases=False,  # Disable for clearer test
        uppercase_labels=False,
    )

    first_batch = next(iter(dataloader))

    # Batches must stay in BSHD layout.
    assert first_batch["input_ids"].ndim == 2, "Should be BSHD format [B, S]"
    assert first_batch["labels"].ndim == 2, "Labels should be BSHD format"

    # No tag character may survive in the labels.
    lbls = first_batch["labels"]
    forbidden_tokens = {
        100: "d (100) from phylo tags should be masked",
        112: "p (112) from phylo tags should be masked",
        95: "_ (95) from phylo tags should be masked",
        124: "| (124) pipes should be masked",
    }
    for token_id, message in forbidden_tokens.items():
        assert token_id not in lbls, message

    # Real DNA must remain trainable.
    assert any(tok in lbls for tok in (65, 67, 71, 84)), "Should have valid DNA tokens"


def test_dataloader_with_genomic_masking(tokenizer_path, tmp_path):
"""Test that create_bshd_dataloader works with genomic masking enabled.

Expand Down
Loading