diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index a83d6d43d..30bc82b6c 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -45,7 +45,7 @@ jobs:
       - name: Set up UV
         uses: astral-sh/setup-uv@v1
         with:
-          version: 0.7.2
+          version: 0.8.22
       - name: Install ruff
         env:
           UV_PROJECT_ENVIRONMENT: ./venv
@@ -60,8 +60,9 @@ jobs:
       - name: Run ruff
         run: |
           source ./venv/bin/activate
-          uv run ruff check . --verbose
-          uv run ruff format --check . --verbose
+          uv run --active ruff --version
+          uv run --active ruff check --verbose .
+          uv run --active ruff format --check --verbose .
 
   import_linting:
     runs-on: ubuntu-latest
diff --git a/examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark.yaml b/examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark.yaml
index e6b66d57c..4490f0a13 100644
--- a/examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark.yaml
+++ b/examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark.yaml
@@ -28,7 +28,7 @@ rng:
   ranked: true
 
 model:
-  _target_: nemo_automodel.components.models.llama.model.build_llama_model
+  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
   pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
   torch_dtype: bf16
 
@@ -87,4 +87,4 @@ optimizer:
 
 lr_scheduler:
   lr_decay_style: cosine
-  min_lr: 1.0e-6
\ No newline at end of file
+  min_lr: 1.0e-6
diff --git a/examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark_2nodes.yaml b/examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark_2nodes.yaml
index 89dae3557..266455c29 100644
--- a/examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark_2nodes.yaml
+++ b/examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark_2nodes.yaml
@@ -28,7 +28,7 @@ rng:
   ranked: true
 
 model:
-  _target_: nemo_automodel.components.models.llama.model.build_llama_model
+  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
   pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
   torch_dtype: bf16
 
@@ -87,4 +87,4 @@ optimizer:
 
 lr_scheduler:
   lr_decay_style: cosine
-  min_lr: 1.0e-6
\ No newline at end of file
+  min_lr: 1.0e-6
diff --git a/nemo_automodel/_transformers/auto_model.py b/nemo_automodel/_transformers/auto_model.py
index 061fae6a1..8dc81eeb0 100644
--- a/nemo_automodel/_transformers/auto_model.py
+++ b/nemo_automodel/_transformers/auto_model.py
@@ -18,6 +18,7 @@
 import logging
 import os
 import types
+from contextlib import contextmanager
 from typing import List, Optional, Union
 
 import torch
@@ -36,10 +37,7 @@
 import nemo_automodel.components.distributed.utils as dist_utils
 from nemo_automodel import __version__
 from nemo_automodel._transformers.registry import ModelRegistry
-from nemo_automodel.components.distributed.init_utils import (
-    get_local_world_size_preinit,
-    get_world_size_safe,
-)
+from nemo_automodel.components.distributed.init_utils import get_local_world_size_preinit, get_world_size_safe
 from nemo_automodel.components.utils.model_utils import resolve_trust_remote_code
 from nemo_automodel.shared.import_utils import safe_import
 from nemo_automodel.shared.utils import dtype_from_str
@@ -49,6 +47,33 @@
 logger = logging.getLogger(__name__)
 
 
+@contextmanager
+def local_torch_dtype(
+    dtype: torch.dtype, model_class_name: str | None = None, default_dtype: torch.dtype = torch.bfloat16
+):
+    """
+    Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
+    If `model_class_name` is provided, it's used to provide a more helpful error message if `dtype` is not valid.
+    """
+    # Just a more helpful error before we call `torch.set_default_dtype` later on, which would crash in this case
+    if isinstance(dtype, str):
+        dtype = default_dtype
+    if not dtype.is_floating_point:
+        if model_class_name is not None:
+            error_message = (
+                f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
+            )
+        else:
+            error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
+        raise ValueError(error_message)
+    original_dtype = torch.get_default_dtype()
+    try:
+        torch.set_default_dtype(dtype)
+        yield
+    finally:
+        torch.set_default_dtype(original_dtype)
+
+
 def _assert_same_signature(original, patched):
     """
     Raise AssertionError if the two call signatures differ.
@@ -157,7 +182,7 @@ def _get_next_fallback_attn(attn_implementation: str) -> str:
     return priorities[0]
 
 
-def _prepare_hf_config_and_flag(pretrained_model_name_or_path, force_hf, kwargs):
+def _prepare_hf_config_and_flag(pretrained_model_name_or_path, force_hf, kwargs, attn_implementation):
     """
     Resolve trust_remote_code default, fetch HF config and determine if model is HF-based.
     """
@@ -165,7 +190,9 @@ def _prepare_hf_config_and_flag(pretrained_model_name_or_path, force_hf, kwargs)
         "trust_remote_code", resolve_trust_remote_code(pretrained_model_name_or_path)
     )
     hf_config = kwargs.pop("config", None) or AutoConfig.from_pretrained(
-        pretrained_model_name_or_path, trust_remote_code=kwargs["trust_remote_code"]
+        pretrained_model_name_or_path,
+        **kwargs,
+        attn_implementation=attn_implementation,
     )
     architectures = getattr(hf_config, "architectures", None) or []
     is_hf_model = (not architectures or architectures[0] not in ModelRegistry.model_arch_name_to_cls) or force_hf
@@ -358,7 +385,9 @@ def from_pretrained(
             `use_liger_kernel=False` or `use_sdpa_patching=False`
         """
         torch_dtype = dtype_from_str(torch_dtype) if torch_dtype != "auto" else torch_dtype
-        hf_config, is_hf_model = _prepare_hf_config_and_flag(pretrained_model_name_or_path, force_hf, kwargs)
+        hf_config, is_hf_model = _prepare_hf_config_and_flag(
+            pretrained_model_name_or_path, force_hf, kwargs, attn_implementation=attn_implementation
+        )
         tp_size, cp_size, has_packed_sequence = _pop_tp_cp_has_packed(kwargs)
         attn_implementation, use_liger_kernel = _apply_preload_overrides(
             is_hf_model, tp_size, cp_size, has_packed_sequence, attn_implementation, use_liger_kernel
@@ -400,7 +429,10 @@ def _retry(**override):
             _download_model_weights(hf_config, pretrained_model_name_or_path)
             logger.info(f"Using custom model implementation for {architectures[0]}")
             kwargs.pop("trust_remote_code", None)
-            return ModelRegistry.model_arch_name_to_cls[architectures[0]](hf_config, *model_args, **kwargs)
+            # TODO(@akoumpa): restore weights after initialization.
+            model_cls = ModelRegistry.model_arch_name_to_cls[architectures[0]]
+            with local_torch_dtype(torch_dtype, model_cls.__name__):
+                return model_cls(hf_config)
 
         # 3. fallback to parent class
         model = None
@@ -533,7 +565,11 @@
 
         # handle model_id passed as config
         if isinstance(config, str):
-            config = AutoConfig.from_pretrained(config, trust_remote_code=kwargs.get("trust_remote_code", False))
+            config = AutoConfig.from_pretrained(
+                config,
+                trust_remote_code=kwargs.get("trust_remote_code", False),
+                attn_implementation=attn_implementation,
+            )
         # 1. if force_hf is True, we will use the parent class to load and return the model as is
         if force_hf:
             return cls._from_config_parent_class(
@@ -547,7 +583,8 @@
         # 2. If we have a custom model implementation available, we prioritize that over HF
         architectures = get_architectures(config)
         if len(architectures) > 0 and architectures[0] in ModelRegistry.model_arch_name_to_cls:
-            return ModelRegistry.model_arch_name_to_cls[architectures[0]](config, *model_args, **kwargs)
+            with local_torch_dtype(torch_dtype, ModelRegistry.model_arch_name_to_cls[architectures[0]].__name__):
+                return ModelRegistry.model_arch_name_to_cls[architectures[0]](config)
 
         # 3. fallback to parent class
         model = None
diff --git a/nemo_automodel/components/models/llama/__init__.py b/nemo_automodel/components/models/llama/__init__.py
index dc373bb9f..070b8c0d7 100644
--- a/nemo_automodel/components/models/llama/__init__.py
+++ b/nemo_automodel/components/models/llama/__init__.py
@@ -11,9 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-"""Custom Llama model implementation for NeMo Automodel."""
-
-from nemo_automodel.components.models.llama.model import LlamaForCausalLM, build_llama_model
-
-__all__ = ["LlamaForCausalLM", "build_llama_model"]
diff --git a/nemo_automodel/components/models/llama/model.py b/nemo_automodel/components/models/llama/model.py
index 4e164bd77..000fe921f 100644
--- a/nemo_automodel/components/models/llama/model.py
+++ b/nemo_automodel/components/models/llama/model.py
@@ -21,7 +21,7 @@
 
 ```yaml
 model:
-  _target_: nemo_automodel.components.models.llama.build_llama_model
+  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
   pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
 ```
 """
@@ -29,11 +28,10 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Callable, Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from transformers import LlamaConfig
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.masking_utils import create_causal_mask
@@ -51,16 +50,9 @@
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, can_return_tuple
 
-from nemo_automodel.components.models.common.combined_projection import (
-    CombinedGateUpMLP,
-    CombinedQKVAttentionMixin,
-)
+from nemo_automodel.components.models.common.combined_projection import CombinedGateUpMLP, CombinedQKVAttentionMixin
 from nemo_automodel.components.models.llama.state_dict_adapter import LlamaStateDictAdapter
-from nemo_automodel.components.moe.utils import BackendConfig
 from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator
-from nemo_automodel.shared.utils import dtype_from_str
-
-__all__ = ["build_llama_model", "LlamaForCausalLM"]
 
 
 check_model_inputs = get_check_model_inputs_decorator()
@@ -360,19 +352,15 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
     def __init__(
         self,
         config: LlamaConfig,
-        backend: Optional[BackendConfig] = None,
     ):
         super().__init__(config)
         self.config = config
-        self.backend = backend or BackendConfig()
         self.model = LlamaModel(config=config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
-        # Create state_dict_adapter if enabled
-        if self.backend.enable_hf_state_dict_adapter:
-            self.state_dict_adapter = LlamaStateDictAdapter(config=self.config)
-
+        # Create state_dict_adapter
+        self.state_dict_adapter = LlamaStateDictAdapter(config=self.config)
         # Initialize weights and apply final processing
         self.post_init()
@@ -490,93 +478,4 @@ def forward(
     )
 
 
-def build_llama_model(pretrained_model_name_or_path: str, **kwargs: Any) -> nn.Module:
-    """Build a custom Llama model with combined projections for efficiency.
-
-    This function loads the config from a HuggingFace model card and builds
-    a custom Llama model with combined QKV and gate_up projections for improved efficiency.
-
-    Args:
-        pretrained_model_name_or_path: HuggingFace model card name (e.g., "meta-llama/Meta-Llama-3-70B")
-        **kwargs: Override config parameters. Common parameters include:
-            - vocab_size: Vocabulary size
-            - hidden_size: Hidden dimension size
-            - num_hidden_layers: Number of transformer layers (useful for testing)
-            - num_attention_heads: Number of attention heads
-            - num_key_value_heads: Number of key/value heads for GQA
-            - intermediate_size: MLP intermediate size
-            - max_position_embeddings: Maximum sequence length
-            - rms_norm_eps: RMSNorm epsilon
-            - rope_theta: RoPE base frequency
-            - attention_dropout: Attention dropout probability
-            - pad_token_id: Padding token ID
-            - attn_implementation: Attention backend ("eager", "sdpa", "flash_attention_2")
-            - torch_dtype: Model dtype (default: bfloat16)
-
-    Returns:
-        LlamaForCausalLM model instance with combined projections
-
-    Example:
-        # Load with default settings (combined projections, bfloat16)
-        model = build_llama_model("meta-llama/Meta-Llama-3-70B")
-
-        # Use SDPA for faster attention
-        model = build_llama_model("meta-llama/Meta-Llama-3-70B",
-                                  attn_implementation="sdpa")
-
-        # Override for testing with fewer layers
-        model = build_llama_model("meta-llama/Meta-Llama-3-70B", num_hidden_layers=4)
-    """
-    # Extract and convert torch_dtype
-    torch_dtype = kwargs.pop("torch_dtype", None)
-    if torch_dtype is not None and isinstance(torch_dtype, str):
-        torch_dtype = dtype_from_str(torch_dtype)
-    elif torch_dtype is None:
-        torch_dtype = torch.bfloat16  # Default to bf16
-
-    # Extract attention implementation if specified, otherwise auto-detect
-    # This matches nemo_automodel/_transformers/auto_model.py approach
-    attn_implementation = kwargs.pop("attn_implementation", None)
-
-    # Load config from HuggingFace (with any overrides from kwargs)
-    config = LlamaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-
-    # Ensure architectures is set for LoRA compatibility
-    if not hasattr(config, "architectures") or config.architectures is None:
-        config.architectures = ["LlamaForCausalLM"]
-
-    # Set attention implementation with auto-detection
-    # Priority: user-specified > existing in config > auto-detect (flash_attention_2 > sdpa > eager)
-    # This matches the logic in nemo_automodel/_transformers/auto_model.py
-    if attn_implementation is not None:
-        config._attn_implementation = attn_implementation
-    elif not hasattr(config, "_attn_implementation") or config._attn_implementation is None:
-        # Auto-detect best available implementation (same as nemo_automodel default)
-        try:
-            # Try flash_attention_2 first (fastest)
-            config._attn_implementation = "flash_attention_2"
-        except (ImportError, ModuleNotFoundError):
-            # Fall back to SDPA if available (PyTorch 2.0+)
-            if hasattr(F, "scaled_dot_product_attention"):
-                config._attn_implementation = "sdpa"
-            else:
-                # Final fallback to eager
-                config._attn_implementation = "eager"
-
-    if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
-        print(f"[build_llama_model] Attention implementation: {config._attn_implementation}")
-        print(f"[build_llama_model] torch_dtype: {torch_dtype}")
-
-    # Create backend config with HF state dict adapter enabled
-    backend = BackendConfig(enable_hf_state_dict_adapter=True)
-
-    # Create model with combined projections
-    model = LlamaForCausalLM(config=config, backend=backend)
-
-    # need to convert model manually since LlamaForCausalLM does not support to(dtype=...)
-    model = model.to(dtype=torch_dtype)
-
-    return model
-
-
 ModelClass = LlamaForCausalLM
diff --git a/tests/unit_tests/_transformers/test_auto_model.py b/tests/unit_tests/_transformers/test_auto_model.py
index c48c73ce1..25f520997 100644
--- a/tests/unit_tests/_transformers/test_auto_model.py
+++ b/tests/unit_tests/_transformers/test_auto_model.py
@@ -103,6 +103,7 @@ def test_from_pretrained_uses_registry_when_available(self):
         # Prepare a fake custom model class and return value
         custom_model_instance = Mock()
         custom_cls = Mock(return_value=custom_model_instance)
+        custom_cls.__name__ = "MockMockMock"
         mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}
 
         returned = NeMoAutoModelForCausalLM.from_pretrained("dummy/path")
@@ -130,6 +131,7 @@ def test_from_config_uses_registry_when_available(self):
         # Registry provides a custom class
         custom_model_instance = Mock()
         custom_cls = Mock(return_value=custom_model_instance)
+        custom_cls.__name__ = "MockMockMock"
         mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}
 
         returned = NeMoAutoModelForCausalLM.from_config(cfg)
@@ -160,6 +162,7 @@ def test_from_pretrained_registry_downloads_checkpoint_files_rank0(self):
         # Prepare a fake custom model class and return value
         custom_model_instance = Mock()
         custom_cls = Mock(return_value=custom_model_instance)
+        custom_cls.__name__ = "MockMockMock"
         mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}
 
         returned = NeMoAutoModelForCausalLM.from_pretrained("dummy/repo-id")
@@ -194,6 +197,7 @@ def test_from_pretrained_registry_downloads_when_dist_uninitialized(self):
         # Prepare a fake custom model class and return value
         custom_model_instance = Mock()
         custom_cls = Mock(return_value=custom_model_instance)
+        custom_cls.__name__ = "MockMockMock"
         mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}
 
         returned = NeMoAutoModelForCausalLM.from_pretrained("dummy/repo-id")
@@ -240,7 +244,8 @@ def test_from_config_with_string_calls_autoconfig(self):
         # Verify AutoConfig.from_pretrained was called with the string
         mock_autoconfig.assert_called_once_with(
             "hf-internal-testing/tiny-random-gpt2",
-            trust_remote_code=False
+            trust_remote_code=False,
+            attn_implementation="flash_attention_2",
         )
         # Verify the model was returned
         assert model is mock_model
@@ -539,6 +544,7 @@ def test_packed_sequence_and_cp_overrides_from_pretrained(
         else:
             custom_model_instance = Mock()
             custom_cls = Mock(return_value=custom_model_instance)
+            custom_cls.__name__ = "MockMockMock"
             mock_registry.model_arch_name_to_cls = {"CustomArch": custom_cls}
             mock_hf_loader.return_value = MagicMock(config={})
 
diff --git a/tests/unit_tests/models/llama/test_llama_custom_model.py b/tests/unit_tests/models/llama/test_llama_custom_model.py
index c820aafdd..65e2c712c 100644
--- a/tests/unit_tests/models/llama/test_llama_custom_model.py
+++ b/tests/unit_tests/models/llama/test_llama_custom_model.py
@@ -21,9 +21,8 @@
 import torch
 from transformers import AutoModelForCausalLM, LlamaConfig
 
-from nemo_automodel.components.models.llama.model import build_llama_model
 from nemo_automodel.components.models.llama.state_dict_adapter import LlamaStateDictAdapter
-
+from nemo_automodel import NeMoAutoModelForCausalLM
 
 
 pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -63,20 +62,20 @@ def test_model_matches_hf_with_adapter_bidirectional(self, tiny_llama_checkpoint
         adapter = LlamaStateDictAdapter(config)
 
         # Load HF model
-        llama_model_hf = (
-            AutoModelForCausalLM.from_pretrained(
-                tiny_llama_checkpoint, attn_implementation="eager", torch_dtype=torch.bfloat16
-            )
-            .to("cuda")
-            .to(torch.bfloat16)  # need to manual cast to bfloat16 since HF initialize weights in float32 dtype
-        )
+        llama_model_hf = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path=tiny_llama_checkpoint,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16,
+        ).to("cuda")
+        llama_model_hf.eval()
 
         # Build custom model
-        llama_model_custom = build_llama_model(
+        llama_model_custom = NeMoAutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path=tiny_llama_checkpoint,
             attn_implementation="eager",
             torch_dtype=torch.bfloat16,
         ).to("cuda")
+        llama_model_custom.eval()
 
         # Verify parameter counts match
         num_params_hf = sum(p.numel() for p in llama_model_hf.parameters())
@@ -90,13 +89,23 @@ def test_model_matches_hf_with_adapter_bidirectional(self, tiny_llama_checkpoint
         custom_state_dict_from_hf = adapter.from_hf(hf_state_dict)
         llama_model_custom.load_state_dict(custom_state_dict_from_hf, strict=True)
 
+        s = adapter.to_hf(llama_model_custom.state_dict())
+
+        for n1, p1 in hf_state_dict.items():
+            p2 = s[n1]
+            assert p1.shape == p2.shape, f"Parameter shape mismatch: {p1.shape} != {p2.shape}"
+            assert p1.dtype == p2.dtype, f"Parameter dtype mismatch: {p1.dtype} != {p2.dtype}"
+            assert p1.device == p2.device, f"Parameter device mismatch: {p1.device} != {p2.device}"
+            assert p1.requires_grad == p2.requires_grad, f"Parameter requires_grad mismatch: {p1.requires_grad} != {p2.requires_grad}"
+            assert torch.allclose(p1, p2, atol=1e-5, rtol=1e-5), f"Parameter mismatch: {p1} != {p2}"
+
         # Generate test inputs
         input_ids = torch.randint(0, config.vocab_size, (1, 10)).to("cuda")
         attention_mask = torch.ones((1, 10)).to("cuda")
 
         # Compare HF → Custom outputs
         with torch.no_grad():
-            output_hf = llama_model_hf(input_ids, attention_mask)
+            output_hf = llama_model_hf(input_ids.clone(), attention_mask.clone())
             output_custom = llama_model_custom(input_ids, attention_mask)
 
         np.testing.assert_allclose(
@@ -112,13 +121,12 @@ def test_model_matches_hf_with_adapter_bidirectional(self, tiny_llama_checkpoint
         hf_state_dict_from_custom = adapter.to_hf(custom_state_dict)
 
         # Create new HF model and load converted state dict
-        llama_model_hf_converted = (
-            AutoModelForCausalLM.from_pretrained(
-                tiny_llama_checkpoint, attn_implementation="eager", torch_dtype=torch.bfloat16
-            )
-            .to("cuda")
-            .to(torch.bfloat16)
-        )
+        llama_model_hf_converted = AutoModelForCausalLM.from_pretrained(
+            tiny_llama_checkpoint,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16
+        ).to("cuda")
+        llama_model_hf_converted.eval()
         llama_model_hf_converted.load_state_dict(hf_state_dict_from_custom, strict=True)
 
         # Compare Custom → HF outputs
@@ -161,7 +169,7 @@ def test_state_dict_adapter_from_hf_combined_projections(self, tiny_llama_checkp
     def test_state_dict_adapter_to_hf(self, tiny_llama_checkpoint):
         """Test converting custom model state dict back to HF format."""
         # Build custom model (which uses adapter internally to load from HF checkpoint)
-        llama_model_custom = build_llama_model(
+        llama_model_custom = NeMoAutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path=tiny_llama_checkpoint,
             attn_implementation="eager",
             torch_dtype=torch.bfloat16,
@@ -187,11 +195,12 @@ def test_export_custom_to_hf_checkpoint(self, tiny_llama_checkpoint):
             export_path = os.path.join(tmpdir, "hf_checkpoint")
 
             # Build custom model
-            llama_model_custom = build_llama_model(
+            llama_model_custom = NeMoAutoModelForCausalLM.from_pretrained(
                 pretrained_model_name_or_path=tiny_llama_checkpoint,
                 attn_implementation="eager",
                 torch_dtype=torch.bfloat16,
             ).to("cuda")
+            llama_model_custom.eval()
 
             # Generate test input
             input_ids = torch.randint(0, config.vocab_size, (1, 10)).to("cuda")
@@ -205,15 +214,12 @@ def test_export_custom_to_hf_checkpoint(self, tiny_llama_checkpoint):
             llama_model_custom.save_pretrained_hf_format(export_path)
 
             # Load from saved HF checkpoint
-            llama_model_hf_loaded = (
-                AutoModelForCausalLM.from_pretrained(
-                    export_path,
-                    attn_implementation="eager",
-                    torch_dtype=torch.bfloat16,
-                )
-                .to("cuda")
-                .to(torch.bfloat16)
-            )
+            llama_model_hf_loaded = AutoModelForCausalLM.from_pretrained(
+                export_path,
+                attn_implementation="eager",
+                torch_dtype=torch.bfloat16,
+            ).to("cuda")
+            llama_model_hf_loaded.eval()
 
             # Compare outputs
             with torch.no_grad():
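
With `build_llama_model` removed, the example configs and the Llama unit tests reach the custom implementation through the single `NeMoAutoModelForCausalLM.from_pretrained` entry point, which dispatches on the config's `architectures` field via `ModelRegistry`. A minimal sketch of the call the updated YAML `model:` block resolves to; the Llama 3.3 model id is the one the benchmark configs use, and actually running this requires access to that checkpoint:

```python
import torch

from nemo_automodel import NeMoAutoModelForCausalLM

# Same call the `_target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained`
# YAML block resolves to; the registry picks the custom LlamaForCausalLM when the
# architecture is registered, and falls back to the HF implementation otherwise.
model = NeMoAutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="meta-llama/Llama-3.3-70B-Instruct",
    torch_dtype=torch.bfloat16,
)
```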
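
The new `local_torch_dtype` context manager is what keeps the registry path honoring `torch_dtype` now that custom classes are constructed directly from the config, where the old `build_llama_model` relied on a post-construction `.to(dtype=...)` cast. A small usage sketch, assuming the helper stays importable from `nemo_automodel._transformers.auto_model`; the `nn.Linear` layer and the `float32` assertion are illustrative and presume the process default dtype was not changed elsewhere:

```python
import torch

from nemo_automodel._transformers.auto_model import local_torch_dtype

# Parameters materialized inside the context use the requested floating-point dtype.
with local_torch_dtype(torch.bfloat16, model_class_name="LlamaForCausalLM"):
    layer = torch.nn.Linear(8, 8)

assert layer.weight.dtype == torch.bfloat16
# The previous process-wide default is restored on exit.
assert torch.get_default_dtype() == torch.float32
```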