
Commit 89cbcff

feat: improve import time (#1076)
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 992a8e0 commit 89cbcff

7 files changed: +308 −195 lines changed

nemo_automodel/__init__.py

Lines changed: 25 additions & 143 deletions
@@ -11,166 +11,48 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import functools
-import importlib
-import inspect
-import logging
-
-from torch.utils.data import _utils as torch_data_utils
-
-# Monkey patch pin_memory to optionally accept a device argument.
-# The device argument was removed in some newer torch versions but we
-# need it for compatibility with torchdata.
-_original_pin_memory_loop = torch_data_utils.pin_memory._pin_memory_loop
-_original_pin_memory = torch_data_utils.pin_memory.pin_memory
-_original_pin_memory_sig = inspect.signature(_original_pin_memory)
-
-if "device" not in _original_pin_memory_sig.parameters:
-
-    @functools.wraps(_original_pin_memory)
-    def _patched_pin_memory(data, device=None):
-        """Patched pin_memory that accepts an optional device argument."""
-        return _original_pin_memory(data)
-
-    @functools.wraps(_original_pin_memory_loop)
-    def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
-        """Patched _pin_memory_loop to accept a device argument."""
-        return _original_pin_memory_loop(in_queue, out_queue, device_id, done_event)
-
-    torch_data_utils.pin_memory.pin_memory = _patched_pin_memory
-    torch_data_utils.pin_memory._pin_memory_loop = _pin_memory_loop
-
-
-# Monkey patch DeviceMesh to fix corner case in mesh slicing
-# Fixes issue where _dim_group_names is accessed without checking if rank is in mesh
-# Based on https://github.com/pytorch/pytorch/pull/169454/files
-try:
-    import torch as _torch
-
-    # Only apply the patch for the specific PyTorch version with the regression
-    # TODO: Remove this once bump up to a newer PyTorch version with the fix
-    if "2.10.0" in _torch.__version__ and "nv25.11" in _torch.__version__:
-        from torch.distributed._mesh_layout import _MeshLayout
-        from torch.distributed.device_mesh import _MeshEnv
-
-        _original_get_slice_mesh_layout = _MeshEnv._get_slice_mesh_layout
-
-        def _patched_get_slice_mesh_layout(self, device_mesh, mesh_dim_names):
-            """
-            Patched _get_slice_mesh_layout based on PyTorch PR #169454.
-            This fixes:
-            1. _dim_group_names access (commit f6c8092)
-            2. Regression in mesh slicing with size-1 dims (PR #169454 / Issue #169381)
-            """
-            # 1. First, build the layout manually to bypass the legacy 'stride < pre_stride' check
-            slice_from_root = device_mesh == self.get_root_mesh(device_mesh)
-            flatten_name_to_root_layout = (
-                {key: mesh._layout for key, mesh in self.root_to_flatten_mapping.setdefault(device_mesh, {}).items()}
-                if slice_from_root
-                else {}
-            )
-
-            mesh_dim_names_list = getattr(device_mesh, "mesh_dim_names", [])
-            valid_mesh_dim_names = [*mesh_dim_names_list, *flatten_name_to_root_layout]
-            if not all(name in valid_mesh_dim_names for name in mesh_dim_names):
-                raise KeyError(f"Invalid mesh_dim_names {mesh_dim_names}. Valid: {valid_mesh_dim_names}")
-
-            layout_sliced = []
-            for name in mesh_dim_names:
-                if name in mesh_dim_names_list:
-                    layout_sliced.append(device_mesh._layout[mesh_dim_names_list.index(name)])
-                elif name in flatten_name_to_root_layout:
-                    layout_sliced.append(flatten_name_to_root_layout[name])
-
-            sliced_sizes = tuple(layout.sizes for layout in layout_sliced)
-            sliced_strides = tuple(layout.strides for layout in layout_sliced)
-
-            # Bypass the 'stride < pre_stride' check that exists in the original
-            # and create the MeshLayout directly.
-            slice_mesh_layout = _MeshLayout(sliced_sizes, sliced_strides)
-
-            if not slice_mesh_layout.check_non_overlap():
-                raise RuntimeError(f"Slicing overlapping dim_names {mesh_dim_names} is not allowed.")
-
-            # 2. Replicate the _dim_group_names fix (commit f6c8092)
-            # We need to return an object that HAS _dim_group_names if the rank is in the mesh
-            if hasattr(device_mesh, "_dim_group_names") and len(device_mesh._dim_group_names) > 0:
-                slice_dim_group_name = []
-                submesh_dim_names = mesh_dim_names if isinstance(mesh_dim_names, tuple) else (mesh_dim_names,)
-                for name in submesh_dim_names:
-                    if name in mesh_dim_names_list:
-                        slice_dim_group_name.append(device_mesh._dim_group_names[mesh_dim_names_list.index(name)])
-                    elif hasattr(device_mesh, "_flatten_mapping") and name in device_mesh._flatten_mapping:
-                        flatten_mesh = device_mesh._flatten_mapping[name]
-                        slice_dim_group_name.append(
-                            flatten_mesh._dim_group_names[flatten_mesh.mesh_dim_names.index(name)]
-                        )
-
-                # Attach the group names to the layout object so the caller can use them
-                object.__setattr__(slice_mesh_layout, "_dim_group_names", slice_dim_group_name)
-
-            return slice_mesh_layout
-
-        # Apply the patch
-        _MeshEnv._get_slice_mesh_layout = _patched_get_slice_mesh_layout
-        logging.getLogger(__name__).debug(f"Applied DeviceMesh fix for PyTorch {_torch.__version__}")
-
-except (ImportError, AttributeError) as e:
-    logging.getLogger(__name__).debug(f"Could not apply DeviceMesh patch: {e}")
-    pass
 
+import importlib
 
 from .package_info import __package_name__, __version__
 
-__all__ = [
-    "recipes",
-    "shared",
-    "components",
-    "__version__",
-    "__package_name__",
-]
+# Keep the base package import lightweight.
+# Heavy dependencies (e.g., torch/transformers) are intentionally imported lazily
+# via __getattr__ so importing tokenizers doesn't pull in the full training stack.
+
+_SUBMODULES = {"recipes", "shared", "components"}
 
-# Promote NeMoAutoModelForCausalLM, AutoModelForImageTextToText into the top level
-# to enable: `from nemo_automodel import NeMoAutoModelForCausalLM`
-try:
-    # adjust this import path if your class lives somewhere else
-    from nemo_automodel._transformers.auto_model import (
-        NeMoAutoModelForCausalLM,
-        NeMoAutoModelForImageTextToText,
-        NeMoAutoModelForSequenceClassification,
-        NeMoAutoModelForTextToWaveform,
-    ) # noqa: I001
-    from nemo_automodel._transformers.auto_tokenizer import NeMoAutoTokenizer
+_LAZY_ATTRS: dict[str, tuple[str, str]] = {
+    "NeMoAutoModelForCausalLM": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForCausalLM"),
+    "NeMoAutoModelForImageTextToText": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForImageTextToText"),
+    "NeMoAutoModelForSequenceClassification": (
+        "nemo_automodel._transformers.auto_model",
+        "NeMoAutoModelForSequenceClassification",
+    ),
+    "NeMoAutoModelForTextToWaveform": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForTextToWaveform"),
+    "NeMoAutoTokenizer": ("nemo_automodel._transformers.auto_tokenizer", "NeMoAutoTokenizer"),
+}
 
-    globals()["NeMoAutoModelForCausalLM"] = NeMoAutoModelForCausalLM
-    globals()["NeMoAutoModelForImageTextToText"] = NeMoAutoModelForImageTextToText
-    globals()["NeMoAutoModelForSequenceClassification"] = NeMoAutoModelForSequenceClassification
-    globals()["NeMoAutoModelForTextToWaveform"] = NeMoAutoModelForTextToWaveform
-    globals()["NeMoAutoTokenizer"] = NeMoAutoTokenizer
-    __all__.append("NeMoAutoModelForCausalLM")
-    __all__.append("NeMoAutoModelForImageTextToText")
-    __all__.append("NeMoAutoModelForSequenceClassification")
-    __all__.append("NeMoAutoModelForTextToWaveform")
-    __all__.append("NeMoAutoTokenizer")
-except:
-    # optional dependency might be missing,
-    # leave the name off the module namespace so other imports still work
-    pass
+__all__ = sorted([*_SUBMODULES, "__version__", "__package_name__", *_LAZY_ATTRS.keys()])
 
 
 def __getattr__(name: str):
     """
-    Lazily import and cache submodules listed in __all__ when accessed.
+    Lazily import and cache selected submodules / exported symbols when accessed.
 
     Raises:
         AttributeError if the name isn’t in __all__.
     """
-    if name in __all__:
-        # import submodule on first access
+    if name in _SUBMODULES:
         module = importlib.import_module(f"{__name__}.{name}")
-        # cache it in globals() so future lookups do not re-import
         globals()[name] = module
         return module
+    if name in _LAZY_ATTRS:
+        module_name, attr_name = _LAZY_ATTRS[name]
+        module = importlib.import_module(module_name)
+        attr = getattr(module, attr_name)
+        globals()[name] = attr
+        return attr
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
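The rewritten __init__.py relies on module-level __getattr__ (PEP 562): `import nemo_automodel` now binds only the package metadata and two small lookup tables, and the heavy torch/transformers stack is pulled in the first time a lazy name is touched. A quick sanity check of that behaviour, as a sketch (the expectation that torch stays out of sys.modules until attribute access is an assumption based on this diff, not something verified here):

import sys

import nemo_automodel  # cheap: no torch/transformers imported yet

print("torch" in sys.modules)  # expected: False

# First access goes through nemo_automodel.__getattr__, imports
# nemo_automodel._transformers.auto_model, and caches the class in globals().
_ = nemo_automodel.NeMoAutoModelForCausalLM

print("torch" in sys.modules)  # expected: True

Before/after import cost can be compared with `python -X importtime -c "import nemo_automodel"`.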

nemo_automodel/_transformers/__init__.py

Lines changed: 28 additions & 7 deletions
@@ -12,14 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
 
-from nemo_automodel._transformers.auto_model import (
-    NeMoAutoModelForCausalLM,
-    NeMoAutoModelForImageTextToText,
-    NeMoAutoModelForSequenceClassification,
-    NeMoAutoModelForTextToWaveform,
-)
-from nemo_automodel._transformers.auto_tokenizer import NeMoAutoTokenizer
+# Keep this package lightweight: importing `nemo_automodel._transformers.*` should not
+# automatically pull in torch + all model code unless a specific symbol is accessed.
+
+_LAZY_ATTRS: dict[str, tuple[str, str]] = {
+    "NeMoAutoModelForCausalLM": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForCausalLM"),
+    "NeMoAutoModelForImageTextToText": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForImageTextToText"),
+    "NeMoAutoModelForSequenceClassification": (
+        "nemo_automodel._transformers.auto_model",
+        "NeMoAutoModelForSequenceClassification",
+    ),
+    "NeMoAutoModelForTextToWaveform": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForTextToWaveform"),
+    "NeMoAutoTokenizer": ("nemo_automodel._transformers.auto_tokenizer", "NeMoAutoTokenizer"),
+}
 
 __all__ = [
     "NeMoAutoModelForCausalLM",
@@ -28,3 +35,17 @@
     "NeMoAutoModelForTextToWaveform",
     "NeMoAutoTokenizer",
 ]
+
+
+def __getattr__(name: str):
+    if name in _LAZY_ATTRS:
+        module_name, attr_name = _LAZY_ATTRS[name]
+        module = importlib.import_module(module_name)
+        attr = getattr(module, attr_name)
+        globals()[name] = attr
+        return attr
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def __dir__():
+    return sorted(__all__)
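The same lazy-attribute table is repeated in this subpackage so that `from nemo_automodel._transformers import NeMoAutoModelForCausalLM` keeps working: a from-import falls back to the module's __getattr__ when the name is not already bound, and __dir__ keeps dir()/tab completion aligned with __all__ even though nothing is imported eagerly. A small illustration based on the code in this diff (the printed list is simply sorted(__all__)):

import nemo_automodel._transformers as tf

print(dir(tf))  # the five NeMo* names from __all__, none of them imported yet

from nemo_automodel._transformers import NeMoAutoTokenizer  # resolved via __getattr__ on first use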

nemo_automodel/_transformers/auto_model.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,10 @@
 
 import torch
 from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from nemo_automodel.shared.torch_patches import apply_torch_patches
+
+apply_torch_patches()
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
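Moving the monkey patches out of the package __init__ and calling apply_torch_patches() here means they run only when model code is actually imported, not on `import nemo_automodel`. The new nemo_automodel/shared/torch_patches.py is one of the seven changed files but is not shown in this excerpt; below is a plausible sketch of it, assembled from the code removed from nemo_automodel/__init__.py above. The module path and function name come from the import in this diff, while the body and the run-once guard are assumptions.

# Hypothetical sketch of nemo_automodel/shared/torch_patches.py -- not the actual file from this commit.
# Patch bodies are copied from the code removed from nemo_automodel/__init__.py; the run-once guard is assumed.
import functools
import inspect

_PATCHES_APPLIED = False


def apply_torch_patches() -> None:
    """Apply torch compatibility patches once, the first time model code is imported."""
    global _PATCHES_APPLIED
    if _PATCHES_APPLIED:
        return
    _PATCHES_APPLIED = True

    from torch.utils.data import _utils as torch_data_utils

    original_pin_memory = torch_data_utils.pin_memory.pin_memory
    original_pin_memory_loop = torch_data_utils.pin_memory._pin_memory_loop

    # Re-add the optional `device` argument that newer torch removed but torchdata expects.
    if "device" not in inspect.signature(original_pin_memory).parameters:

        @functools.wraps(original_pin_memory)
        def _patched_pin_memory(data, device=None):
            return original_pin_memory(data)

        @functools.wraps(original_pin_memory_loop)
        def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
            return original_pin_memory_loop(in_queue, out_queue, device_id, done_event)

        torch_data_utils.pin_memory.pin_memory = _patched_pin_memory
        torch_data_utils.pin_memory._pin_memory_loop = _pin_memory_loop

    # The DeviceMesh slicing fix removed from __init__.py would be applied here as well,
    # behind the same "2.10.0"/"nv25.11" version check on torch.__version__.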

nemo_automodel/_transformers/auto_tokenizer.py

Lines changed: 39 additions & 19 deletions
@@ -15,11 +15,6 @@
 import logging
 from typing import Callable, Optional, Type, Union
 
-from transformers import AutoConfig, AutoTokenizer
-
-from nemo_automodel._transformers.tokenization.nemo_auto_tokenizer import NeMoAutoTokenizerWithBosEosEnforced
-from nemo_automodel._transformers.tokenization.registry import TokenizerRegistry
-
 logger = logging.getLogger(__name__)
 
 
@@ -35,14 +30,24 @@ def _get_model_type(pretrained_model_name_or_path: str, trust_remote_code: bool
         The model_type string, or None if it cannot be determined
     """
     try:
+        from transformers import AutoConfig
+
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
         return getattr(config, "model_type", None)
     except Exception as e:
         logger.debug(f"Could not load config to determine model type: {e}")
         return None
 
 
-class NeMoAutoTokenizer(AutoTokenizer):
+def _get_tokenizer_registry():
+    # Import lazily to avoid pulling in optional/custom backends (and transformers)
+    # when users only do `from nemo_automodel import NeMoAutoTokenizer`.
+    from nemo_automodel._transformers.tokenization.registry import TokenizerRegistry
+
+    return TokenizerRegistry
+
+
+class NeMoAutoTokenizer:
     """
     Auto tokenizer class that dispatches to appropriate tokenizer implementations.
 
@@ -62,13 +67,7 @@ class NeMoAutoTokenizer(AutoTokenizer):
     """
 
     # Make registry accessible at class level
-    _registry = TokenizerRegistry
-
-    def __init__(self):
-        raise EnvironmentError(
-            f"{self.__class__.__name__} is designed to be instantiated using the "
-            f"`{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` method."
-        )
+    _registry = None
 
     @classmethod
     def register(cls, model_type: str, tokenizer_cls: Union[Type, Callable]) -> None:
@@ -79,7 +78,7 @@ def register(cls, model_type: str, tokenizer_cls: Union[Type, Callable]) -> None
             model_type: The model type string (e.g., "mistral", "llama")
             tokenizer_cls: The tokenizer class or factory function
         """
-        cls._registry.register(model_type, tokenizer_cls)
+        _get_tokenizer_registry().register(model_type, tokenizer_cls)
 
     @classmethod
     def from_pretrained(
@@ -106,19 +105,26 @@
         """
         # If force_hf, just use the base HF AutoTokenizer
        if force_hf:
-            return super().from_pretrained(
+            from transformers import AutoTokenizer
+
+            return AutoTokenizer.from_pretrained(
                 pretrained_model_name_or_path, *args, trust_remote_code=trust_remote_code, **kwargs
             )
 
         # Try to determine model type from config
         model_type = _get_model_type(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
 
-        if model_type and cls._registry.has_custom_tokenizer(model_type):
-            tokenizer_cls = cls._registry.get_tokenizer_cls(model_type)
-            logger.info(f"Using custom tokenizer {tokenizer_cls.__name__} for model type '{model_type}'")
-            return tokenizer_cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        registry = _get_tokenizer_registry()
+
+        if not force_default and model_type:
+            tokenizer_cls = registry.get_custom_tokenizer_cls(model_type)
+            if tokenizer_cls is not None:
+                logger.info(f"Using custom tokenizer {tokenizer_cls.__name__} for model type '{model_type}'")
+                return tokenizer_cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
         # Fall back to default BOS/EOS enforced tokenizer
+        from nemo_automodel._transformers.tokenization.nemo_auto_tokenizer import NeMoAutoTokenizerWithBosEosEnforced
+
         return NeMoAutoTokenizerWithBosEosEnforced.from_pretrained(
             pretrained_model_name_or_path, *args, trust_remote_code=trust_remote_code, **kwargs
         )
@@ -129,3 +135,17 @@
     "NeMoAutoTokenizerWithBosEosEnforced",
     "TokenizerRegistry",
 ]
+
+
+def __getattr__(name: str):
+    if name == "TokenizerRegistry":
+        return _get_tokenizer_registry()
+    if name == "NeMoAutoTokenizerWithBosEosEnforced":
+        from nemo_automodel._transformers.tokenization.nemo_auto_tokenizer import NeMoAutoTokenizerWithBosEosEnforced
+
+        return NeMoAutoTokenizerWithBosEosEnforced
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def __dir__():
+    return sorted(__all__)
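With these changes, NeMoAutoTokenizer no longer subclasses transformers.AutoTokenizer and only imports transformers inside the code paths that need it. The dispatch order in from_pretrained is: force_hf uses the plain HF AutoTokenizer; otherwise the detected model_type is looked up in the lazily imported registry unless force_default is set; otherwise it falls back to NeMoAutoTokenizerWithBosEosEnforced. A usage sketch (the model id is a placeholder, and force_default is assumed to be a keyword of from_pretrained since it is referenced in the body but its full signature is not shown in this hunk):

from nemo_automodel import NeMoAutoTokenizer  # transformers is not imported until from_pretrained runs

# Default path: detect model_type from the config, prefer a registered custom tokenizer,
# otherwise fall back to the BOS/EOS-enforced wrapper.
tok = NeMoAutoTokenizer.from_pretrained("org/model-name")  # placeholder model id

# Escape hatches visible in this diff:
hf_tok = NeMoAutoTokenizer.from_pretrained("org/model-name", force_hf=True)        # plain HF AutoTokenizer
default = NeMoAutoTokenizer.from_pretrained("org/model-name", force_default=True)  # skip the custom registry

# Custom tokenizers are attached per model_type through the lazily imported registry:
# NeMoAutoTokenizer.register("my_model_type", MyCustomTokenizer)  # MyCustomTokenizer is hypothetical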
