Commit 76580e3

misc collator, test, and state lints
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent e740697 commit 76580e3

17 files changed: +305 additions, -152 deletions


bionemo-recipes/models/amplify/src/amplify/state.py

Lines changed: 8 additions & 3 deletions

@@ -67,8 +67,8 @@ def apply_transforms(
     source: Union[nn.Module, _ModelState],
     target: TargetModuleT,
     mapping: Dict[str, str],
-    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
-    state_dict_ignored_entries: List = [],
+    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
+    state_dict_ignored_entries: Optional[List] = None,
     cast_dtype: Optional[torch.dtype] = None,
 ) -> TargetModuleT:
     """Transform the state dictionary of a source module to match the structure of a target module's state dictionary.
@@ -126,6 +126,11 @@ def scale_weights(ctx):
     This function is particularly useful when adapting models from different frameworks or
     when consolidating models with different architectural changes.
     """
+    if transforms is None:
+        transforms = []
+    if state_dict_ignored_entries is None:
+        state_dict_ignored_entries = []
+
     # Track dtypes to make sure they weren't modified during conversion.
     target_orig_dtypes = extract_dtypes(target.named_parameters())
@@ -318,7 +323,7 @@ def __call__(self, ctx: TransformCTX) -> TransformCTX:
         try:
             source_match = source_matches[target_index]
         except IndexError as e:
-            logger.error(f"Enountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
+            logger.error(f"Encountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
             raise e
         if accepts_var_args:
             source_values = [source_dict[k] for k in source_match]
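
The signature change above replaces mutable default arguments with `None` sentinels, restoring the empty-list defaults inside the function body. As a quick, self-contained illustration of the pitfall this lint fixes (toy functions, not code from this repo):

```python
def append_bad(item, bucket=[]):
    # The default list is created once, at function definition time, and shared by every call.
    bucket.append(item)
    return bucket


def append_good(item, bucket=None):
    # A None sentinel gives each call its own fresh list, mirroring the fix above.
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket


assert append_bad(1) == [1]
assert append_bad(2) == [1, 2]  # surprising: state leaked from the first call
assert append_good(1) == [1]
assert append_good(2) == [2]  # independent calls stay independent
```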

bionemo-recipes/models/esm2/src/esm/collator.py

Lines changed: 31 additions & 13 deletions

@@ -156,8 +156,11 @@ def __call__(self, features, return_tensors=None):
         sequence processing capabilities. When pad_to_multiple_of is used, an additional
         mock sequence is appended to reach the desired total length.
         """
+        if return_tensors is not None and return_tensors != "pt":
+            raise NotImplementedError(f"Only return_tensors='pt' is supported, got '{return_tensors}'")
+
         # Perform the masking with the BSHD collator.
-        bshd_batch = self.collator(features)
+        bshd_batch = self.collator(features, return_tensors=return_tensors)

         # Create the flattened batch to get the cu_seq_lens_q and cu_seq_lens_k values.
         packed_batch = _pt_flatten_collate(features, return_position_ids=self.return_position_ids)
@@ -279,33 +282,48 @@ def __iter__(self):
         samples = []
         current_length = 0
         for sample in iter(self.dataset):
-            current_length += self._padded_len(len(sample["input_ids"]))
+            sample_length = len(sample["input_ids"])
+            if sample_length > self.max_tokens_per_batch:
+                raise ValueError(
+                    f"TokenPackingDataset: Sample length ({sample_length}) exceeds max_tokens_per_batch "
+                    f"({self.max_tokens_per_batch}). Set truncation or a maximum length in your tokenizer or dataset to"
+                    "ensure all samples fit within max_tokens_per_batch."
+                )
+
+            current_length += self._padded_len(sample_length)
             if current_length == self.max_tokens_per_batch:
                 yield [*samples, sample]
                 samples = []
                 current_length = 0

             elif current_length > self.max_tokens_per_batch:
                 if not self.split_samples:
-                    # If we are not splitting samples, we can just yield the current batch (before this sample) and
-                    # start a new one.
-                    yield samples
+                    # Yield the current batch (before this sample) and start a new one with this sample.
+                    if samples:
+                        yield samples
                     samples = [sample]
-
+                    current_length = self._padded_len(sample_length)
                 else:
-                    # Calculate how many padded tokens are already in the batch
-                    tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"]))
+                    # Calculate how many padded tokens are already in the batch.
+                    tokens_in_batch = current_length - self._padded_len(sample_length)
                     # Calculate how many tokens we can fit from this sample, ensuring the
                     # padded length doesn't exceed the remaining capacity.
                     tokens_available = self.max_tokens_per_batch - tokens_in_batch
                     if self.pad_sequences_to_be_divisible_by is not None:
                         d = self.pad_sequences_to_be_divisible_by
                         tokens_available = (tokens_available // d) * d
-                    first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
-                    yield [*samples, first_part]
-                    samples = [remaining_part]
-
-                    current_length = self._padded_len(len(samples[0]["input_ids"]))
+                    if tokens_available <= 0:
+                        # Remaining capacity is less than pad_sequences_to_be_divisible_by;
+                        # can't fit any tokens from this sample. Yield current batch and start fresh.
+                        if samples:
+                            yield samples
+                        samples = [sample]
+                        current_length = self._padded_len(sample_length)
+                    else:
+                        first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
+                        yield [*samples, first_part]
+                        samples = [remaining_part]
+                        current_length = self._padded_len(len(samples[0]["input_ids"]))
             else:
                 samples.append(sample)
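
The `__iter__` changes tighten the greedy token-packing loop: oversized samples now raise a `ValueError`, empty batches are never yielded, and `current_length` is reset correctly whenever a sample starts a new batch. A standalone sketch of the basic packing rule this loop implements, simplified to ignore sample splitting (`padded_len` and `pack` are illustrative names, not the recipe's API):

```python
def padded_len(n, divisor):
    # Round a sequence length up to the next multiple of the pad divisor.
    return -(-n // divisor) * divisor


def pack(lengths, max_tokens, divisor):
    """Greedily accumulate samples until the padded total would exceed max_tokens."""
    batches, current, total = [], [], 0
    for n in lengths:
        padded = padded_len(n, divisor)
        if total + padded > max_tokens and current:
            batches.append(current)
            current, total = [], 0
        current.append(n)
        total += padded
    if current:
        batches.append(current)
    return batches


# max_tokens=12 with a pad divisor of 8: every sample rounds up to 8 tokens, so only
# one sample fits per batch -- the same scenario the new collator tests exercise.
assert pack([5, 3, 4], max_tokens=12, divisor=8) == [[5], [3], [4]]
```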

bionemo-recipes/models/esm2/src/esm/state.py

Lines changed: 8 additions & 3 deletions

@@ -67,8 +67,8 @@ def apply_transforms(
     source: Union[nn.Module, _ModelState],
     target: TargetModuleT,
     mapping: Dict[str, str],
-    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
-    state_dict_ignored_entries: List = [],
+    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
+    state_dict_ignored_entries: Optional[List] = None,
     cast_dtype: Optional[torch.dtype] = None,
 ) -> TargetModuleT:
     """Transform the state dictionary of a source module to match the structure of a target module's state dictionary.
@@ -126,6 +126,11 @@ def scale_weights(ctx):
     This function is particularly useful when adapting models from different frameworks or
     when consolidating models with different architectural changes.
     """
+    if transforms is None:
+        transforms = []
+    if state_dict_ignored_entries is None:
+        state_dict_ignored_entries = []
+
     # Track dtypes to make sure they weren't modified during conversion.
     target_orig_dtypes = extract_dtypes(target.named_parameters())
@@ -318,7 +323,7 @@ def __call__(self, ctx: TransformCTX) -> TransformCTX:
         try:
             source_match = source_matches[target_index]
         except IndexError as e:
-            logger.error(f"Enountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
+            logger.error(f"Encountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
             raise e
         if accepts_var_args:
             source_values = [source_dict[k] for k in source_match]

bionemo-recipes/models/esm2/tests/common/README.md

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ Shared test infrastructure for BioNeMo models. One base class, **BaseModelTest**
 
 ## Structure
 
-```
+```text
 tests/common/
 ├── __init__.py              # Public API exports
 ├── test_modeling_common.py  # BaseModelTest, TestTolerances

bionemo-recipes/models/esm2/tests/common/__init__.py

Lines changed: 0 additions & 15 deletions

@@ -13,21 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 """Common test utilities for BioNeMo models.
 
 This package provides reusable test infrastructure following HuggingFace

bionemo-recipes/models/esm2/tests/common/fixtures.py

Lines changed: 3 additions & 18 deletions

@@ -13,21 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 """Shared test fixtures for BioNeMo models."""
 
 import os
@@ -63,7 +48,7 @@ def use_te_debug():
 
     os.environ["NVTE_DEBUG"] = "1"
     yield
-    del os.environ["NVTE_DEBUG"]
+    os.environ.pop("NVTE_DEBUG", None)
 
 
 ALL_RECIPES = [
@@ -138,6 +123,6 @@ def te_attn_backend(request):
 
     yield request.param
 
-    del os.environ["NVTE_FUSED_ATTN"]
-    del os.environ["NVTE_FLASH_ATTN"]
+    os.environ.pop("NVTE_FUSED_ATTN", None)
+    os.environ.pop("NVTE_FLASH_ATTN", None)
     _attention_backends["backend_selection_requires_update"] = True
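
Both fixtures now tear down with `os.environ.pop(..., None)` instead of `del`. A minimal snippet (the variable name is invented) showing why `pop` with a default is the safer cleanup:

```python
import os

os.environ["EXAMPLE_FLAG"] = "1"

# del raises KeyError if the variable has already been removed,
# for example by another fixture or by the code under test.
del os.environ["EXAMPLE_FLAG"]
try:
    del os.environ["EXAMPLE_FLAG"]
except KeyError:
    pass  # the failure mode the pop() change avoids

# pop with a default is idempotent: safe whether or not the key still exists.
os.environ.pop("EXAMPLE_FLAG", None)
os.environ.pop("EXAMPLE_FLAG", None)
```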

bionemo-recipes/models/esm2/tests/common/test_modeling_common.py

Lines changed: 12 additions & 5 deletions

@@ -15,6 +15,7 @@
 
 """Common test class for BioNeMo models, following HuggingFace transformers patterns."""
 
+import fnmatch
 import gc
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
@@ -30,9 +31,12 @@
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, set_seed
 
 
-HAS_DATA_CENTER_GPU = any(
-    gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"]
-)
+try:
+    HAS_DATA_CENTER_GPU = torch.cuda.is_available() and any(
+        gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"]
+    )
+except (RuntimeError, AssertionError):
+    HAS_DATA_CENTER_GPU = False
 
 
 @dataclass
@@ -283,7 +287,9 @@ def msg(x):
             if should_be_fp8:
                 if f"{name}.weight" in set(model._tied_weights_keys):
                     continue  # Skip tied weights
-                elif hasattr(model, "_do_not_quantize") and name in model._do_not_quantize:
+                elif hasattr(model, "_do_not_quantize") and any(
+                    fnmatch.fnmatch(name, pattern) for pattern in model._do_not_quantize
+                ):
                     continue  # Skip weights that should be kept in bf16
                 assert isinstance(module.weight, QuantizedTensor), f"Module {name} weight is not a Float8Tensor"
@@ -340,13 +346,14 @@ def get_reference_model(
         model.to("cuda")
         return model

-    def get_reference_model_no_weights(self) -> PreTrainedModel:
+    def get_reference_model_no_weights(self, **kwargs) -> PreTrainedModel:
         """Load the reference HuggingFace model with random weights."""
         return self.get_upstream_model_class()(
             AutoConfig.from_pretrained(
                 self.get_upstream_model_id(),
                 dtype=torch.float32,
                 revision=self.get_upstream_model_revision(),
+                **kwargs,
             )
         )
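
With this change, entries in a model's `_do_not_quantize` list are treated as shell-style glob patterns via `fnmatch` rather than exact names. A small illustration of the difference (the module names are invented):

```python
import fnmatch

do_not_quantize = ["lm_head", "encoder.layers.*.layernorm"]
name = "encoder.layers.3.layernorm"

# Exact membership misses per-layer modules covered by a wildcard entry...
assert name not in do_not_quantize

# ...while fnmatch expands "*", so any layer index matches the pattern.
assert any(fnmatch.fnmatch(name, pattern) for pattern in do_not_quantize)
```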

bionemo-recipes/models/esm2/tests/test_collator.py

Lines changed: 94 additions & 0 deletions

@@ -968,6 +968,100 @@ def __iter__(self):
     )


+def test_token_packing_dataset_padding_split_remaining_capacity_below_divisor():
+    """Test that split mode handles remaining capacity below pad_sequences_to_be_divisible_by.
+
+    When the remaining batch capacity (after rounding down to the pad divisor) is 0,
+    the current batch must be yielded and the sample starts a new batch. Without this
+    guard, _split_sample_by_num_tokens would be called with tokens_available=0 and crash.
+
+    max=12, pad=8, split=True:
+    - s1: raw=5, padded=8. current=8 < 12. Append.
+    - s2: raw=3, padded=8. current=8+8=16 > 12.
+      tokens_in_batch=8, tokens_available=12-8=4, rounded to (4//8)*8=0 → yield [s1], fresh batch.
+    - s3: raw=4, padded=8. current=8+8=16 > 12. Same: yield [s2], fresh batch.
+    """
+
+    class MockDataset(torch.utils.data.IterableDataset):
+        def __iter__(self):
+            yield {"input_ids": list(range(5))}  # padded to 8
+            yield {"input_ids": list(range(3))}  # padded to 8
+            yield {"input_ids": list(range(4))}  # padded to 8
+
+    dataset = MockDataset()
+    token_packing_dataset = TokenPackingDataset(
+        dataset,
+        max_tokens_per_batch=12,
+        pad_sequences_to_be_divisible_by=8,
+        split_samples=True,
+        drop_last=False,
+    )
+    batches = list(token_packing_dataset)
+
+    # Each sample pads to 8; only one fits per batch (8 < 12, but 8+8=16 > 12,
+    # and remaining capacity 4 rounds down to 0 with pad=8).
+    assert len(batches) == 3
+    assert [len(s["input_ids"]) for s in batches[0]] == [5]
+    assert [len(s["input_ids"]) for s in batches[1]] == [3]
+    assert [len(s["input_ids"]) for s in batches[2]] == [4]
+
+
+def test_token_packing_dataset_padding_no_split_yields_before_overflow():
+    """Test that non-split mode correctly yields the batch before a padded sample overflows.
+
+    max=12, pad=8, split=False:
+    - s1: raw=5, padded=8. current=8 < 12. Append.
+    - s2: raw=3, padded=8. current=8+8=16 > 12. Yield [s1], start fresh with s2.
+    - s3: raw=4, padded=8. current=8+8=16 > 12. Yield [s2], start fresh with s3.
+    """
+
+    class MockDataset(torch.utils.data.IterableDataset):
+        def __iter__(self):
+            yield {"input_ids": list(range(5))}  # padded to 8
+            yield {"input_ids": list(range(3))}  # padded to 8
+            yield {"input_ids": list(range(4))}  # padded to 8
+
+    dataset = MockDataset()
+    token_packing_dataset = TokenPackingDataset(
+        dataset,
+        max_tokens_per_batch=12,
+        pad_sequences_to_be_divisible_by=8,
+        split_samples=False,
+        drop_last=False,
+    )
+    batches = list(token_packing_dataset)
+
+    # Each sample pads to 8, only one fits per batch (8 < 12, but 8+8=16 > 12)
+    assert len(batches) == 3
+    assert [len(s["input_ids"]) for s in batches[0]] == [5]
+    assert [len(s["input_ids"]) for s in batches[1]] == [3]
+    assert [len(s["input_ids"]) for s in batches[2]] == [4]
+
+
+def test_token_packing_dataset_oversized_sample_raises():
+    """Test that a sample exceeding max_tokens_per_batch raises a ValueError.
+
+    Users should set truncation or a maximum length in their tokenizer/dataset to ensure
+    all samples fit within max_tokens_per_batch.
+    """
+
+    class MockDataset(torch.utils.data.IterableDataset):
+        def __iter__(self):
+            yield {"input_ids": list(range(5))}  # fits
+            yield {"input_ids": list(range(25))}  # exceeds max of 10
+
+    dataset = MockDataset()
+    token_packing_dataset = TokenPackingDataset(
+        dataset,
+        max_tokens_per_batch=10,
+        split_samples=False,
+        drop_last=False,
+    )
+
+    with pytest.raises(ValueError, match="Sample length.*exceeds max_tokens_per_batch"):
+        list(token_packing_dataset)
+
+
 def test_token_packing_dataset_with_padding_split_drop_last_false(tokenizer):
     """Test that with drop_last=False, all batches except the last have exactly max_tokens."""
     pad_divisor = 4
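
The first new test's docstring walks through the capacity arithmetic; here are the same numbers as a quick standalone check of the divisor rounding that triggers the new `tokens_available <= 0` branch:

```python
max_tokens, pad = 12, 8
tokens_in_batch = 8  # one sample already packed, padded up to 8
tokens_available = max_tokens - tokens_in_batch  # 4 raw tokens of capacity left
rounded = (tokens_available // pad) * pad  # floor to the pad divisor -> 0
assert rounded == 0  # nothing from the next sample fits, so the batch is yielded as-is
```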
