
Commit 24d1e2c

rebase on moe fix pr
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 11b7a7f commit 24d1e2c

File tree

6 files changed, +55 -55 lines changed

bionemo-recipes/models/mixtral/collator.py

Lines changed: 31 additions & 13 deletions
```diff
@@ -156,8 +156,11 @@ def __call__(self, features, return_tensors=None):
         sequence processing capabilities. When pad_to_multiple_of is used, an additional
         mock sequence is appended to reach the desired total length.
         """
+        if return_tensors is not None and return_tensors != "pt":
+            raise NotImplementedError(f"Only return_tensors='pt' is supported, got '{return_tensors}'")
+
         # Perform the masking with the BSHD collator.
-        bshd_batch = self.collator(features)
+        bshd_batch = self.collator(features, return_tensors=return_tensors)
 
         # Create the flattened batch to get the cu_seq_lens_q and cu_seq_lens_k values.
         packed_batch = _pt_flatten_collate(features, return_position_ids=self.return_position_ids)
```
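
For reference, the new guard fails fast on any tensor backend other than PyTorch instead of silently ignoring the argument. A minimal usage sketch, assuming `collator` is an instance of the sequence-packing collator patched above and that the toy features carry the fields it expects:

```python
# Illustrative sketch only; `collator` stands in for the sequence-packing collator above.
features = [{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}]

batch = collator(features)                       # defaults to PyTorch ("pt") tensors
batch = collator(features, return_tensors="pt")  # explicit "pt" is also accepted

try:
    collator(features, return_tensors="np")      # any other backend is rejected up front
except NotImplementedError as err:
    print(err)  # Only return_tensors='pt' is supported, got 'np'
```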
```diff
@@ -279,33 +282,48 @@ def __iter__(self):
         samples = []
         current_length = 0
         for sample in iter(self.dataset):
-            current_length += self._padded_len(len(sample["input_ids"]))
+            sample_length = len(sample["input_ids"])
+            if sample_length > self.max_tokens_per_batch:
+                raise ValueError(
+                    f"TokenPackingDataset: Sample length ({sample_length}) exceeds max_tokens_per_batch "
+                    f"({self.max_tokens_per_batch}). Set truncation or a maximum length in your tokenizer or dataset to"
+                    "ensure all samples fit within max_tokens_per_batch."
+                )
+
+            current_length += self._padded_len(sample_length)
             if current_length == self.max_tokens_per_batch:
                 yield [*samples, sample]
                 samples = []
                 current_length = 0
 
             elif current_length > self.max_tokens_per_batch:
                 if not self.split_samples:
-                    # If we are not splitting samples, we can just yield the current batch (before this sample) and
-                    # start a new one.
-                    yield samples
+                    # Yield the current batch (before this sample) and start a new one with this sample.
+                    if samples:
+                        yield samples
                     samples = [sample]
-
+                    current_length = self._padded_len(sample_length)
                 else:
-                    # Calculate how many padded tokens are already in the batch
-                    tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"]))
+                    # Calculate how many padded tokens are already in the batch.
+                    tokens_in_batch = current_length - self._padded_len(sample_length)
                     # Calculate how many tokens we can fit from this sample, ensuring the
                     # padded length doesn't exceed the remaining capacity.
                     tokens_available = self.max_tokens_per_batch - tokens_in_batch
                     if self.pad_sequences_to_be_divisible_by is not None:
                         d = self.pad_sequences_to_be_divisible_by
                         tokens_available = (tokens_available // d) * d
-                    first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
-                    yield [*samples, first_part]
-                    samples = [remaining_part]
-
-                    current_length = self._padded_len(len(samples[0]["input_ids"]))
+                    if tokens_available <= 0:
+                        # Remaining capacity is less than pad_sequences_to_be_divisible_by;
+                        # can't fit any tokens from this sample. Yield current batch and start fresh.
+                        if samples:
+                            yield samples
+                        samples = [sample]
+                        current_length = self._padded_len(sample_length)
+                    else:
+                        first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
+                        yield [*samples, first_part]
+                        samples = [remaining_part]
+                        current_length = self._padded_len(len(samples[0]["input_ids"]))
             else:
                 samples.append(sample)
 
```
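
The packing loop accumulates samples until `max_tokens_per_batch` is reached and, when `split_samples` is set, splits an overflowing sample across batches. Below is a simplified, self-contained sketch of the same accumulate-or-split decision, working on raw lengths only and ignoring the pad-to-multiple handling (all names here are illustrative, not the recipe's API):

```python
from typing import Iterable, Iterator, List


def pack_by_tokens(lengths: Iterable[int], max_tokens: int, split: bool = False) -> Iterator[List[int]]:
    """Group sample lengths into batches whose totals stay at or below max_tokens."""
    batch: List[int] = []
    current = 0
    for length in lengths:
        if length > max_tokens:
            raise ValueError(f"Sample length ({length}) exceeds max_tokens ({max_tokens}).")
        current += length
        if current == max_tokens:
            yield [*batch, length]
            batch, current = [], 0
        elif current > max_tokens:
            if not split:
                # Emit the batch built so far and start a new one with this sample.
                if batch:
                    yield batch
                batch, current = [length], length
            else:
                # Fill the remaining capacity with part of this sample and carry the rest forward.
                available = max_tokens - (current - length)
                yield [*batch, available]
                batch, current = [length - available], length - available
        else:
            batch.append(length)
    if batch:
        yield batch


# Lengths 6, 5, 4 packed into batches of at most 10 tokens, with splitting enabled.
print(list(pack_by_tokens([6, 5, 4], max_tokens=10, split=True)))  # [[6, 4], [1, 4]]
```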

bionemo-recipes/models/mixtral/state.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -67,8 +67,8 @@ def apply_transforms(
     source: Union[nn.Module, _ModelState],
     target: TargetModuleT,
     mapping: Dict[str, str],
-    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
-    state_dict_ignored_entries: List = [],
+    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
+    state_dict_ignored_entries: Optional[List] = None,
     cast_dtype: Optional[torch.dtype] = None,
 ) -> TargetModuleT:
     """Transform the state dictionary of a source module to match the structure of a target module's state dictionary.
@@ -126,6 +126,11 @@ def scale_weights(ctx):
     This function is particularly useful when adapting models from different frameworks or
     when consolidating models with different architectural changes.
     """
+    if transforms is None:
+        transforms = []
+    if state_dict_ignored_entries is None:
+        state_dict_ignored_entries = []
+
     # Track dtypes to make sure they weren't modified during conversion.
     target_orig_dtypes = extract_dtypes(target.named_parameters())
 
@@ -318,7 +323,7 @@ def __call__(self, ctx: TransformCTX) -> TransformCTX:
         try:
             source_match = source_matches[target_index]
         except IndexError as e:
-            logger.error(f"Enountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
+            logger.error(f"Encountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
             raise e
         if accepts_var_args:
             source_values = [source_dict[k] for k in source_match]
```
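
The signature change from `transforms=[]` to `transforms=None` (with the `if ... is None` initialization inside the body) sidesteps Python's shared-mutable-default pitfall: a default list is created once at function definition time and reused across every call. A standalone illustration with made-up function names:

```python
def register_bad(transform, transforms=[]):  # default list is created once and shared
    transforms.append(transform)
    return transforms


def register_good(transform, transforms=None):  # the pattern apply_transforms now uses
    if transforms is None:
        transforms = []
    transforms.append(transform)
    return transforms


print(register_bad("a"))   # ['a']
print(register_bad("b"))   # ['a', 'b']  <- state leaked from the first call
print(register_good("a"))  # ['a']
print(register_good("b"))  # ['b']       <- fresh list on every call
```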

bionemo-recipes/models/mixtral/tests/common/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@ Shared test infrastructure for BioNeMo models. One base class, **BaseModelTest**
 
 ## Structure
 
-```
+```text
 tests/common/
 ├── __init__.py              # Public API exports
 ├── test_modeling_common.py  # BaseModelTest, TestTolerances
```

bionemo-recipes/models/mixtral/tests/common/__init__.py

Lines changed: 0 additions & 15 deletions
```diff
@@ -13,21 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 """Common test utilities for BioNeMo models.
 
 This package provides reusable test infrastructure following HuggingFace
```

bionemo-recipes/models/mixtral/tests/common/fixtures.py

Lines changed: 3 additions & 18 deletions
```diff
@@ -13,21 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 """Shared test fixtures for BioNeMo models."""
 
 import os
@@ -63,7 +48,7 @@ def use_te_debug():
 
     os.environ["NVTE_DEBUG"] = "1"
     yield
-    del os.environ["NVTE_DEBUG"]
+    os.environ.pop("NVTE_DEBUG", None)
 
 
 ALL_RECIPES = [
@@ -138,6 +123,6 @@ def te_attn_backend(request):
 
     yield request.param
 
-    del os.environ["NVTE_FUSED_ATTN"]
-    del os.environ["NVTE_FLASH_ATTN"]
+    os.environ.pop("NVTE_FUSED_ATTN", None)
+    os.environ.pop("NVTE_FLASH_ATTN", None)
     _attention_backends["backend_selection_requires_update"] = True
```
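
Switching teardown from `del os.environ[...]` to `os.environ.pop(..., None)` keeps these fixtures from raising `KeyError` when a variable has already been unset, for example by the test itself or by a failure before the fixture finished setup. A quick illustration:

```python
import os

os.environ["NVTE_DEBUG"] = "1"
del os.environ["NVTE_DEBUG"]        # fine the first time
os.environ.pop("NVTE_DEBUG", None)  # still fine: a no-op once the key is gone

try:
    del os.environ["NVTE_DEBUG"]    # a second del raises
except KeyError:
    print("del raises on a missing key; pop(key, None) does not")
```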

bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py

Lines changed: 12 additions & 5 deletions
```diff
@@ -15,6 +15,7 @@
 
 """Common test class for BioNeMo models, following HuggingFace transformers patterns."""
 
+import fnmatch
 import gc
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
@@ -30,9 +31,12 @@
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, set_seed
 
 
-HAS_DATA_CENTER_GPU = any(
-    gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"]
-)
+try:
+    HAS_DATA_CENTER_GPU = torch.cuda.is_available() and any(
+        gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"]
+    )
+except (RuntimeError, AssertionError):
+    HAS_DATA_CENTER_GPU = False
 
 
 @dataclass
```
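
Guarding the device-name probe matters because `torch.cuda.get_device_name(0)` raises when no usable GPU is present (typically a `RuntimeError` on CUDA builds with no device, an `AssertionError` on CPU-only builds), which would previously break module import on CPU machines. A defensive probe in the same spirit, written as a hypothetical standalone helper:

```python
import torch


def has_gpu_named(substrings: tuple = ("H100", "H200", "B100", "B200", "B300")) -> bool:
    """Return True if the first visible GPU's name contains any of the given substrings."""
    try:
        if not torch.cuda.is_available():
            return False
        device_name = torch.cuda.get_device_name(0).upper()
    except (RuntimeError, AssertionError):
        # No driver, no device, or a CPU-only torch build.
        return False
    return any(s in device_name for s in substrings)


print(has_gpu_named())  # False on CPU-only machines instead of an import-time crash
```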
```diff
@@ -283,7 +287,9 @@ def msg(x):
             if should_be_fp8:
                 if f"{name}.weight" in set(model._tied_weights_keys):
                     continue  # Skip tied weights
-                elif hasattr(model, "_do_not_quantize") and name in model._do_not_quantize:
+                elif hasattr(model, "_do_not_quantize") and any(
+                    fnmatch.fnmatch(name, pattern) for pattern in model._do_not_quantize
+                ):
                     continue  # Skip weights that should be kept in bf16
                 assert isinstance(module.weight, QuantizedTensor), f"Module {name} weight is not a Float8Tensor"
 
```
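
Replacing the exact-membership check with `fnmatch` lets `_do_not_quantize` hold shell-style wildcard patterns, so one entry can exempt a whole family of modules from the FP8 assertion. A small illustration with made-up module names and patterns:

```python
import fnmatch

do_not_quantize = ["lm_head", "model.layers.*.mlp.gate"]  # hypothetical patterns
module_names = [
    "lm_head",
    "model.layers.0.mlp.gate",
    "model.layers.7.mlp.gate",
    "model.layers.0.self_attn.q_proj",
]

for name in module_names:
    skipped = any(fnmatch.fnmatch(name, pattern) for pattern in do_not_quantize)
    print(f"{name}: {'skip FP8 check' if skipped else 'expect FP8 weight'}")
```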

```diff
@@ -340,13 +346,14 @@ def get_reference_model(
         model.to("cuda")
         return model
 
-    def get_reference_model_no_weights(self) -> PreTrainedModel:
+    def get_reference_model_no_weights(self, **kwargs) -> PreTrainedModel:
         """Load the reference HuggingFace model with random weights."""
         return self.get_upstream_model_class()(
             AutoConfig.from_pretrained(
                 self.get_upstream_model_id(),
                 dtype=torch.float32,
                 revision=self.get_upstream_model_revision(),
+                **kwargs,
             )
         )
```
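
One use for the `**kwargs` passthrough is building smaller random-weight reference models in tests: `AutoConfig.from_pretrained` applies extra keyword arguments as config overrides, so a subclass can shrink the model it instantiates. The snippet below is illustrative only; which overrides are valid depends on the upstream model's config class:

```python
# Hypothetical usage inside a BaseModelTest subclass; num_hidden_layers is a common
# HF config field, but the valid overrides depend on the upstream model's config.
tiny_model = self.get_reference_model_no_weights(num_hidden_layers=2)
assert tiny_model.config.num_hidden_layers == 2
```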
