Skip to content

Commit d22ea6a

Browse files
committed
simplify from_pretrained/from_config
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 4de5d0f commit d22ea6a

File tree

1 file changed

+98
-80
lines changed

1 file changed

+98
-80
lines changed

nemo_automodel/_transformers/auto_model.py

Lines changed: 98 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,6 @@
2323
import torch
2424
import torch.distributed as dist
2525
from torch.nn.attention import SDPBackend, sdpa_kernel
26-
from transformers import (
27-
AutoConfig,
28-
AutoModelForCausalLM,
29-
AutoModelForImageTextToText,
30-
AutoModelForSequenceClassification,
31-
AutoModelForTextToWaveform,
32-
PreTrainedModel,
33-
)
34-
from transformers.modeling_utils import _get_resolved_checkpoint_files
35-
from transformers.models.auto.auto_factory import _BaseAutoModelClass
3626

3727
from nemo_automodel import __version__
3828
from nemo_automodel._transformers.registry import ModelRegistry
@@ -41,9 +31,20 @@
4131
get_local_world_size_preinit,
4232
get_world_size_safe,
4333
)
34+
from nemo_automodel.components.distributed.utils import FirstRankPerNode
4435
from nemo_automodel.components.utils.model_utils import resolve_trust_remote_code
4536
from nemo_automodel.shared.import_utils import safe_import
4637
from nemo_automodel.shared.utils import dtype_from_str
38+
from transformers import (
39+
AutoConfig,
40+
AutoModelForCausalLM,
41+
AutoModelForImageTextToText,
42+
AutoModelForSequenceClassification,
43+
AutoModelForTextToWaveform,
44+
PreTrainedModel,
45+
)
46+
from transformers.modeling_utils import _get_resolved_checkpoint_files
47+
from transformers.models.auto.auto_factory import _BaseAutoModelClass
4748

4849
HAS_LIGER_KERNEL, liger_kernel_trf = safe_import("liger_kernel.transformers")
4950
logger = logging.getLogger(__name__)
@@ -216,6 +217,38 @@ def _verify_sdpa_support(model, is_hf_model, cp_size):
216217
raise ValueError("Model does not support SDPA required for context parallelism")
217218

218219

220+
def _download_model_weights(hf_config, pretrained_model_name_or_path):
    """Resolve (and thereby fetch) the checkpoint files of a non-local model.

    Nothing is done when ``pretrained_model_name_or_path`` is an existing local
    directory, or when ``torch.distributed`` is initialized and
    ``get_local_rank_preinit()`` is not 0.

    Args:
        hf_config: Model config object; its ``_commit_hash`` attribute, when
            present, is forwarded as ``commit_hash`` to the checkpoint resolver.
        pretrained_model_name_or_path: Model identifier or local directory path.
    """
    # Same evaluation order as before: the rank check runs first, so ranks that
    # are filtered out never touch the path argument.
    is_downloading_rank = not dist.is_initialized() or get_local_rank_preinit() == 0
    if not is_downloading_rank or os.path.isdir(pretrained_model_name_or_path):
        return

    # Node count = total ranks // ranks per node. The previous ``%`` + 1 formula
    # evaluated to 1 for any evenly filled cluster (e.g. 16 % 8 + 1), so the
    # multi-node warning below was never emitted in that case.
    num_nodes = max(1, get_world_size_safe() // get_local_world_size_preinit())
    if num_nodes > 1:
        # Use the module-level logger rather than the root ``logging`` module so the
        # message follows this module's logging configuration.
        logger.info(
            f"""Downloading model weights on {num_nodes} nodes. This incurs high storage usage.
            It is recommended to download once with `hf download` and pass in the downloaded path to the `pretrained_model_name_or_path` argument."""
        )

    # NOTE(review): only processes that passed the guard above enter this context
    # manager. If FirstRankPerNode issues a collective barrier on enter/exit, the
    # remaining ranks never reach a matching call here -- confirm this cannot hang.
    with FirstRankPerNode():
        _get_resolved_checkpoint_files(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder="",
            variant=None,
            gguf_file=None,
            from_tf=False,
            from_flax=False,
            use_safetensors=None,
            cache_dir=None,
            force_download=False,
            proxies=None,
            local_files_only=False,
            token=None,
            user_agent={"file_type": "model", "framework": "pytorch", "from_auto_class": False},
            revision="main",
            commit_hash=getattr(hf_config, "_commit_hash", None),
            is_remote_code=False,
            transformers_explicit_filename=None,
        )
250+
251+
219252
class _BaseNeMoAutoModelClass(_BaseAutoModelClass):
220253
"""
221254
Drop-in replacement for ``_BaseAutoModelClass`` that includes custom-kernels.
@@ -238,6 +271,24 @@ class _BaseNeMoAutoModelClass(_BaseAutoModelClass):
238271
Liger patch. Unsupported models will silently fall back.
239272
"""
240273

274+
@classmethod
def _from_pretrained_parent_class(cls, *args, **kwargs):
    """Call the parent ``from_pretrained`` with the ``NeMo`` name prefix removed.

    While the parent call runs, ``cls.__name__`` has its leading ``"NeMo"``
    stripped (when present). The rename applies to whichever class this
    helper is invoked on, and the original name is restored afterwards,
    including when the parent call raises.

    Args:
        *args: Positional arguments forwarded to the parent ``from_pretrained``.
        **kwargs: Keyword arguments forwarded to the parent ``from_pretrained``.

    Returns:
        Whatever the parent ``from_pretrained`` returns.
    """
    original_name = cls.__name__
    if original_name.startswith("NeMo"):
        cls.__name__ = original_name[4:]
    try:
        return super().from_pretrained(*args, **kwargs)
    finally:
        # Restore on every exit path: callers catch errors from this call and
        # retry, and must not observe a permanently renamed class.
        cls.__name__ = original_name
282+
283+
@classmethod
def _from_config_parent_class(cls, *args, **kwargs):
    """Call the parent ``from_config`` with the ``NeMo`` name prefix removed.

    While the parent call runs, ``cls.__name__`` has its leading ``"NeMo"``
    stripped (when present). The rename applies to whichever class this
    helper is invoked on, and the original name is restored afterwards,
    including when the parent call raises.

    Args:
        *args: Positional arguments forwarded to the parent ``from_config``.
        **kwargs: Keyword arguments forwarded to the parent ``from_config``.

    Returns:
        Whatever the parent ``from_config`` returns.
    """
    original_name = cls.__name__
    if original_name.startswith("NeMo"):
        cls.__name__ = original_name[4:]
    try:
        return super().from_config(*args, **kwargs)
    finally:
        # Restore on every exit path so a failed parent call cannot leave the
        # class renamed for later calls.
        cls.__name__ = original_name
291+
241292
@classmethod
242293
def from_pretrained(
243294
cls,
@@ -324,66 +375,36 @@ def _retry(**override):
324375
**kwargs,
325376
)
326377

327-
# load model
378+
# 1. if force_hf is True, we will use the parent class to load and return the model as is
379+
if force_hf:
380+
return _BaseNeMoAutoModelClass._from_pretrained_parent_class(
381+
pretrained_model_name_or_path,
382+
*model_args,
383+
torch_dtype=torch_dtype,
384+
attn_implementation=attn_implementation,
385+
**kwargs,
386+
)
387+
388+
# 2. If we have a custom model implementation available, we prioritize that over HF
389+
if hf_config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
390+
# if we are able to init the custom model, we will now download the model weights on local rank 0
391+
_download_model_weights(hf_config, pretrained_model_name_or_path)
392+
logger.info(f"Using custom model implementation for {hf_config.architectures[0]}")
393+
kwargs.pop("trust_remote_code", None)
394+
return ModelRegistry.model_arch_name_to_cls[hf_config.architectures[0]](hf_config, *model_args, **kwargs)
395+
396+
# 3. fallback to parent class
328397
model = None
329398
try:
330-
name = cls.__name__
331-
if name.startswith("NeMo"):
332-
cls.__name__ = name[4:]
333-
if not force_hf:
334-
try:
335-
# if we have a custom model implementation available, we prioritize that over HF
336-
if hf_config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
337-
kwargs.pop("trust_remote_code", None)
338-
model = ModelRegistry.model_arch_name_to_cls[hf_config.architectures[0]](
339-
hf_config, *model_args, **kwargs
340-
)
341-
# if we are able to init the custom model, we will now download the model weights on local rank 0
342-
if (not dist.is_initialized() or get_local_rank_preinit() == 0) and not os.path.isdir(
343-
pretrained_model_name_or_path
344-
):
345-
num_nodes = (get_world_size_safe() % get_local_world_size_preinit()) + 1 # 1-indexed
346-
if num_nodes > 1:
347-
logging.info(
348-
f"""Downloading model weights on {num_nodes} nodes. This incurs high storage usage.
349-
It is recommended to download once with `hf download` and pass in the downloaded path to the `pretrained_model_name_or_path` argument."""
350-
)
351-
_get_resolved_checkpoint_files(
352-
pretrained_model_name_or_path=pretrained_model_name_or_path,
353-
subfolder="",
354-
variant=None,
355-
gguf_file=None,
356-
from_tf=False,
357-
from_flax=False,
358-
use_safetensors=None,
359-
cache_dir=None,
360-
force_download=False,
361-
proxies=None,
362-
local_files_only=False,
363-
token=None,
364-
user_agent={"file_type": "model", "framework": "pytorch", "from_auto_class": False},
365-
revision="main",
366-
commit_hash=getattr(hf_config, "_commit_hash", None),
367-
is_remote_code=False,
368-
transformers_explicit_filename=None,
369-
)
370-
if dist.is_initialized():
371-
dist.barrier()
372-
logger.info(f"Using custom model implementation for {hf_config.architectures[0]}")
373-
return model
374-
except Exception as e:
375-
logger.error(f"Failed to use custom model implementation with error: {e}")
376-
377399
if quantization_config is not None:
378400
kwargs["quantization_config"] = quantization_config
379-
model = super().from_pretrained(
401+
model = _BaseNeMoAutoModelClass._from_pretrained_parent_class(
380402
pretrained_model_name_or_path,
381403
*model_args,
382404
torch_dtype=torch_dtype,
383405
attn_implementation=attn_implementation,
384406
**kwargs,
385407
)
386-
cls.__name__ = name
387408
except ValueError as e:
388409
if "does not support" in str(e):
389410
if model is not None:
@@ -499,35 +520,32 @@ def _retry(**override):
499520
**kwargs,
500521
)
501522

502-
# load model
523+
# 1. if force_hf is True, we will use the parent class to load and return the model as is
524+
if force_hf:
525+
return _BaseNeMoAutoModelClass._from_config_parent_class(
526+
config,
527+
*model_args,
528+
torch_dtype=torch_dtype,
529+
attn_implementation=attn_implementation,
530+
**kwargs,
531+
)
532+
533+
# 2. If we have a custom model implementation available, we prioritize that over HF
534+
if config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
535+
raise NotImplementedError("Custom model implementation is not supported for from_config")
536+
537+
# 3. fallback to parent class
503538
model = None
504539
try:
505-
name = cls.__name__
506-
if name.startswith("NeMo"):
507-
cls.__name__ = name[4:]
508-
if not force_hf:
509-
try:
510-
# if we have a custom model implementation available, we prioritize that over HF
511-
if config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
512-
kwargs.pop("trust_remote_code", None)
513-
model = ModelRegistry.model_arch_name_to_cls[config.architectures[0]](
514-
config, *model_args, **kwargs
515-
)
516-
logger.info(f"Using custom model implementation for {config.architectures[0]}")
517-
return model
518-
except Exception as e:
519-
logger.error(f"Failed to use custom model implementation with error: {e}")
520-
521540
if quantization_config is not None:
522541
kwargs["quantization_config"] = quantization_config
523-
model = super().from_config(
542+
model = _BaseNeMoAutoModelClass._from_config_parent_class(
524543
config,
525544
*model_args,
526-
attn_implementation=attn_implementation,
527545
torch_dtype=torch_dtype,
546+
attn_implementation=attn_implementation,
528547
**kwargs,
529548
)
530-
cls.__name__ = name
531549
except ValueError as e:
532550
if "does not support" in str(e):
533551
logging.warning("Falling back to eager attention.")

0 commit comments

Comments
 (0)