7 changes: 4 additions & 3 deletions .github/workflows/cicd-main.yml
@@ -45,7 +45,7 @@ jobs:
- name: Set up UV
uses: astral-sh/setup-uv@v1
with:
version: 0.7.2
version: 0.8.22
- name: Install ruff
env:
UV_PROJECT_ENVIRONMENT: ./venv
@@ -60,8 +60,9 @@ jobs:
- name: Run ruff
run: |
source ./venv/bin/activate
uv run ruff check . --verbose
uv run ruff format --check . --verbose
uv run --active ruff --version
uv run --active ruff check --verbose .
uv run --active ruff format --check --verbose .

import_linting:
runs-on: ubuntu-latest
@@ -28,7 +28,7 @@ rng:
ranked: true

model:
_target_: nemo_automodel.components.models.llama.model.build_llama_model
_target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
torch_dtype: bf16

@@ -87,4 +87,4 @@ optimizer:

lr_scheduler:
lr_decay_style: cosine
min_lr: 1.0e-6
min_lr: 1.0e-6
@@ -28,7 +28,7 @@ rng:
ranked: true

model:
_target_: nemo_automodel.components.models.llama.model.build_llama_model
_target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
torch_dtype: bf16

@@ -87,4 +87,4 @@ optimizer:

lr_scheduler:
lr_decay_style: cosine
min_lr: 1.0e-6
min_lr: 1.0e-6
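Both recipe configs above now point `model._target_` at `NeMoAutoModelForCausalLM.from_pretrained` instead of the removed `build_llama_model` helper. A minimal sketch (not part of the diff) of the programmatic equivalent of the updated `model:` block, with the import path taken from the `_target_` string and the arguments from the YAML keys:

```python
# Sketch only: what the updated `model:` config block resolves to when instantiated.
# Import path follows the `_target_` string; argument names follow the YAML keys.
from nemo_automodel import NeMoAutoModelForCausalLM

model = NeMoAutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="meta-llama/Llama-3.3-70B-Instruct",
    torch_dtype="bf16",
)
```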
57 changes: 47 additions & 10 deletions nemo_automodel/_transformers/auto_model.py
@@ -18,6 +18,7 @@
import logging
import os
import types
from contextlib import contextmanager
from typing import List, Optional, Union

import torch
@@ -36,10 +36,7 @@
import nemo_automodel.components.distributed.utils as dist_utils
from nemo_automodel import __version__
from nemo_automodel._transformers.registry import ModelRegistry
from nemo_automodel.components.distributed.init_utils import (
get_local_world_size_preinit,
get_world_size_safe,
)
from nemo_automodel.components.distributed.init_utils import get_local_world_size_preinit, get_world_size_safe
from nemo_automodel.components.utils.model_utils import resolve_trust_remote_code
from nemo_automodel.shared.import_utils import safe_import
from nemo_automodel.shared.utils import dtype_from_str
@@ -49,6 +47,33 @@
logger = logging.getLogger(__name__)


@contextmanager
def local_torch_dtype(
dtype: torch.dtype, model_class_name: str | None = None, default_dtype: torch.dtype = torch.bfloat16
):
"""
Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
If `model_class_name` is provided, it is included in the error message raised when `dtype` is not a floating-point dtype.
"""
# Raise a more helpful error here, before the later `torch.set_default_dtype` call, which would otherwise crash on a non-floating-point dtype
if isinstance(dtype, str):
dtype = default_dtype
if not dtype.is_floating_point:
if model_class_name is not None:
error_message = (
f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
)
else:
error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
raise ValueError(error_message)
original_dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(dtype)
yield
finally:
torch.set_default_dtype(original_dtype)


def _assert_same_signature(original, patched):
"""
Raise AssertionError if the two call signatures differ.
@@ -157,15 +182,17 @@ def _get_next_fallback_attn(attn_implementation: str) -> str:
return priorities[0]


def _prepare_hf_config_and_flag(pretrained_model_name_or_path, force_hf, kwargs):
def _prepare_hf_config_and_flag(pretrained_model_name_or_path, force_hf, kwargs, attn_implementation):
"""
Resolve trust_remote_code default, fetch HF config and determine if model is HF-based.
"""
kwargs["trust_remote_code"] = kwargs.get(
"trust_remote_code", resolve_trust_remote_code(pretrained_model_name_or_path)
)
hf_config = kwargs.pop("config", None) or AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=kwargs["trust_remote_code"]
pretrained_model_name_or_path,
**kwargs,
attn_implementation=attn_implementation,
)
architectures = getattr(hf_config, "architectures", None) or []
is_hf_model = (not architectures or architectures[0] not in ModelRegistry.model_arch_name_to_cls) or force_hf
@@ -358,7 +385,9 @@ def from_pretrained(
`use_liger_kernel=False` or `use_sdpa_patching=False`
"""
torch_dtype = dtype_from_str(torch_dtype) if torch_dtype != "auto" else torch_dtype
hf_config, is_hf_model = _prepare_hf_config_and_flag(pretrained_model_name_or_path, force_hf, kwargs)
hf_config, is_hf_model = _prepare_hf_config_and_flag(
pretrained_model_name_or_path, force_hf, kwargs, attn_implementation=attn_implementation
)
tp_size, cp_size, has_packed_sequence = _pop_tp_cp_has_packed(kwargs)
attn_implementation, use_liger_kernel = _apply_preload_overrides(
is_hf_model, tp_size, cp_size, has_packed_sequence, attn_implementation, use_liger_kernel
@@ -400,7 +429,10 @@ def _retry(**override):
_download_model_weights(hf_config, pretrained_model_name_or_path)
logger.info(f"Using custom model implementation for {architectures[0]}")
kwargs.pop("trust_remote_code", None)
return ModelRegistry.model_arch_name_to_cls[architectures[0]](hf_config, *model_args, **kwargs)
# TODO(@akoumpa): restore weights after initialization.
model_cls = ModelRegistry.model_arch_name_to_cls[architectures[0]]
with local_torch_dtype(torch_dtype, model_cls.__name__):
return model_cls(hf_config)

# 3. fallback to parent class
model = None
@@ -533,7 +565,11 @@ def _retry(**override):

# handle model_id passed as config
if isinstance(config, str):
config = AutoConfig.from_pretrained(config, trust_remote_code=kwargs.get("trust_remote_code", False))
config = AutoConfig.from_pretrained(
config,
trust_remote_code=kwargs.get("trust_remote_code", False),
attn_implementation=attn_implementation,
)
# 1. if force_hf is True, we will use the parent class to load and return the model as is
if force_hf:
return cls._from_config_parent_class(
@@ -547,7 +583,8 @@
# 2. If we have a custom model implementation available, we prioritize that over HF
architectures = get_architectures(config)
if len(architectures) > 0 and architectures[0] in ModelRegistry.model_arch_name_to_cls:
return ModelRegistry.model_arch_name_to_cls[architectures[0]](config, *model_args, **kwargs)
with local_torch_dtype(torch_dtype, ModelRegistry.model_arch_name_to_cls[architectures[0]].__name__):
return ModelRegistry.model_arch_name_to_cls[architectures[0]](config)

# 3. fallback to parent class
model = None
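For reference, a minimal usage sketch (not part of the diff) of the `local_torch_dtype` context manager added above; the layer size is arbitrary and the float32 default is assumed:

```python
import torch

from nemo_automodel._transformers.auto_model import local_torch_dtype

# Inside the context, torch's default floating-point dtype is bf16, so newly
# created floating-point tensors and parameters are allocated in bf16.
with local_torch_dtype(torch.bfloat16, "LlamaForCausalLM"):
    layer = torch.nn.Linear(16, 16)

# The previous default is restored on exit (even if construction raised),
# while the layer built inside the context keeps its bf16 parameters.
assert torch.get_default_dtype() == torch.float32  # assuming the usual float32 default
assert layer.weight.dtype == torch.bfloat16

# Entering the context with a non-floating-point dtype (e.g. torch.int8) raises
# ValueError up front, before torch.set_default_dtype would crash less helpfully.
```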
6 changes: 0 additions & 6 deletions nemo_automodel/components/models/llama/__init__.py
@@ -11,9 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom Llama model implementation for NeMo Automodel."""

from nemo_automodel.components.models.llama.model import LlamaForCausalLM, build_llama_model

__all__ = ["LlamaForCausalLM", "build_llama_model"]
111 changes: 5 additions & 106 deletions nemo_automodel/components/models/llama/model.py
@@ -21,19 +21,18 @@

```yaml
model:
_target_: nemo_automodel.components.models.llama.build_llama_model
_target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
```
"""

from __future__ import annotations

import os
from typing import Any, Callable, Optional, Union
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers.cache_utils import Cache, DynamicCache
from transformers.masking_utils import create_causal_mask
@@ -51,16 +50,9 @@
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, can_return_tuple

from nemo_automodel.components.models.common.combined_projection import (
CombinedGateUpMLP,
CombinedQKVAttentionMixin,
)
from nemo_automodel.components.models.common.combined_projection import CombinedGateUpMLP, CombinedQKVAttentionMixin
from nemo_automodel.components.models.llama.state_dict_adapter import LlamaStateDictAdapter
from nemo_automodel.components.moe.utils import BackendConfig
from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator
from nemo_automodel.shared.utils import dtype_from_str

__all__ = ["build_llama_model", "LlamaForCausalLM"]

check_model_inputs = get_check_model_inputs_decorator()

@@ -360,19 +352,15 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
def __init__(
self,
config: LlamaConfig,
backend: Optional[BackendConfig] = None,
):
super().__init__(config)
self.config = config
self.backend = backend or BackendConfig()
self.model = LlamaModel(config=config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

# Create state_dict_adapter if enabled
if self.backend.enable_hf_state_dict_adapter:
self.state_dict_adapter = LlamaStateDictAdapter(config=self.config)

# Create state_dict_adapter
self.state_dict_adapter = LlamaStateDictAdapter(config=self.config)
# Initialize weights and apply final processing
self.post_init()

@@ -490,93 +478,4 @@ def forward(
)


def build_llama_model(pretrained_model_name_or_path: str, **kwargs: Any) -> nn.Module:
"""Build a custom Llama model with combined projections for efficiency.

This function loads the config from a HuggingFace model card and builds
a custom Llama model with combined QKV and gate_up projections for improved efficiency.

Args:
pretrained_model_name_or_path: HuggingFace model card name (e.g., "meta-llama/Meta-Llama-3-70B")
**kwargs: Override config parameters. Common parameters include:
- vocab_size: Vocabulary size
- hidden_size: Hidden dimension size
- num_hidden_layers: Number of transformer layers (useful for testing)
- num_attention_heads: Number of attention heads
- num_key_value_heads: Number of key/value heads for GQA
- intermediate_size: MLP intermediate size
- max_position_embeddings: Maximum sequence length
- rms_norm_eps: RMSNorm epsilon
- rope_theta: RoPE base frequency
- attention_dropout: Attention dropout probability
- pad_token_id: Padding token ID
- attn_implementation: Attention backend ("eager", "sdpa", "flash_attention_2")
- torch_dtype: Model dtype (default: bfloat16)

Returns:
LlamaForCausalLM model instance with combined projections

Example:
# Load with default settings (combined projections, bfloat16)
model = build_llama_model("meta-llama/Meta-Llama-3-70B")

# Use SDPA for faster attention
model = build_llama_model("meta-llama/Meta-Llama-3-70B",
attn_implementation="sdpa")

# Override for testing with fewer layers
model = build_llama_model("meta-llama/Meta-Llama-3-70B", num_hidden_layers=4)
"""
# Extract and convert torch_dtype
torch_dtype = kwargs.pop("torch_dtype", None)
if torch_dtype is not None and isinstance(torch_dtype, str):
torch_dtype = dtype_from_str(torch_dtype)
elif torch_dtype is None:
torch_dtype = torch.bfloat16 # Default to bf16

# Extract attention implementation if specified, otherwise auto-detect
# This matches nemo_automodel/_transformers/auto_model.py approach
attn_implementation = kwargs.pop("attn_implementation", None)

# Load config from HuggingFace (with any overrides from kwargs)
config = LlamaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

# Ensure architectures is set for LoRA compatibility
if not hasattr(config, "architectures") or config.architectures is None:
config.architectures = ["LlamaForCausalLM"]

# Set attention implementation with auto-detection
# Priority: user-specified > existing in config > auto-detect (flash_attention_2 > sdpa > eager)
# This matches the logic in nemo_automodel/_transformers/auto_model.py
if attn_implementation is not None:
config._attn_implementation = attn_implementation
elif not hasattr(config, "_attn_implementation") or config._attn_implementation is None:
# Auto-detect best available implementation (same as nemo_automodel default)
try:
# Try flash_attention_2 first (fastest)
config._attn_implementation = "flash_attention_2"
except (ImportError, ModuleNotFoundError):
# Fall back to SDPA if available (PyTorch 2.0+)
if hasattr(F, "scaled_dot_product_attention"):
config._attn_implementation = "sdpa"
else:
# Final fallback to eager
config._attn_implementation = "eager"

if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
print(f"[build_llama_model] Attention implementation: {config._attn_implementation}")
print(f"[build_llama_model] torch_dtype: {torch_dtype}")

# Create backend config with HF state dict adapter enabled
backend = BackendConfig(enable_hf_state_dict_adapter=True)

# Create model with combined projections
model = LlamaForCausalLM(config=config, backend=backend)

# need to convert model manually since LlamaForCausalLM does not support to(dtype=...)
model = model.to(dtype=torch_dtype)

return model


ModelClass = LlamaForCausalLM
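With `build_llama_model` removed, the custom `LlamaForCausalLM` is now reached through the registry path in `from_pretrained` (per the TODO in the diff, restoring weights after initialization is still pending there). A simplified sketch of that path, omitting the weight download, patching, and override handling done by the real code:

```python
# Simplified sketch of the registry path after this change: the custom class is
# constructed directly from the HF config under the requested default dtype,
# instead of being built under the ambient default and converted afterwards
# as build_llama_model used to do.
import torch
from transformers import AutoConfig

from nemo_automodel._transformers.auto_model import local_torch_dtype
from nemo_automodel.components.models.llama.model import LlamaForCausalLM

hf_config = AutoConfig.from_pretrained("meta-llama/Llama-3.3-70B-Instruct")

with local_torch_dtype(torch.bfloat16, LlamaForCausalLM.__name__):
    model = LlamaForCausalLM(hf_config)  # the HF state-dict adapter is now always attached
```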
8 changes: 7 additions & 1 deletion tests/unit_tests/_transformers/test_auto_model.py
@@ -103,6 +103,7 @@ def test_from_pretrained_uses_registry_when_available(self):
# Prepare a fake custom model class and return value
custom_model_instance = Mock()
custom_cls = Mock(return_value=custom_model_instance)
custom_cls.__name__ = "MockMockMock"
mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}

returned = NeMoAutoModelForCausalLM.from_pretrained("dummy/path")
@@ -130,6 +131,7 @@ def test_from_config_uses_registry_when_available(self):
# Registry provides a custom class
custom_model_instance = Mock()
custom_cls = Mock(return_value=custom_model_instance)
custom_cls.__name__ = "MockMockMock"
mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}

returned = NeMoAutoModelForCausalLM.from_config(cfg)
@@ -160,6 +162,7 @@ def test_from_pretrained_registry_downloads_checkpoint_files_rank0(self):
# Prepare a fake custom model class and return value
custom_model_instance = Mock()
custom_cls = Mock(return_value=custom_model_instance)
custom_cls.__name__ = "MockMockMock"
mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}

returned = NeMoAutoModelForCausalLM.from_pretrained("dummy/repo-id")
@@ -194,6 +197,7 @@ def test_from_pretrained_registry_downloads_when_dist_uninitialized(self):
# Prepare a fake custom model class and return value
custom_model_instance = Mock()
custom_cls = Mock(return_value=custom_model_instance)
custom_cls.__name__ = "MockMockMock"
mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}

returned = NeMoAutoModelForCausalLM.from_pretrained("dummy/repo-id")
@@ -240,7 +244,8 @@ def test_from_config_with_string_calls_autoconfig(self):
# Verify AutoConfig.from_pretrained was called with the string
mock_autoconfig.assert_called_once_with(
"hf-internal-testing/tiny-random-gpt2",
trust_remote_code=False
trust_remote_code=False,
attn_implementation="flash_attention_2",
)
# Verify the model was returned
assert model is mock_model
@@ -539,6 +544,7 @@ def test_packed_sequence_and_cp_overrides_from_pretrained(
else:
custom_model_instance = Mock()
custom_cls = Mock(return_value=custom_model_instance)
custom_cls.__name__ = "MockMockMock"
mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}

mock_hf_loader.return_value = MagicMock(config={})