11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -import functools |
15 | | -import importlib |
16 | | -import inspect |
17 | | -import logging |
18 | | - |
19 | | -from torch.utils.data import _utils as torch_data_utils |
20 | | - |
21 | | -# Monkey patch pin_memory to optionally accept a device argument. |
22 | | -# The device argument was removed in some newer torch versions but we |
23 | | -# need it for compatibility with torchdata. |
24 | | -_original_pin_memory_loop = torch_data_utils.pin_memory._pin_memory_loop |
25 | | -_original_pin_memory = torch_data_utils.pin_memory.pin_memory |
26 | | -_original_pin_memory_sig = inspect.signature(_original_pin_memory) |
27 | | - |
28 | | -if "device" not in _original_pin_memory_sig.parameters: |
29 | | - |
30 | | - @functools.wraps(_original_pin_memory) |
31 | | - def _patched_pin_memory(data, device=None): |
32 | | - """Patched pin_memory that accepts an optional device argument.""" |
33 | | - return _original_pin_memory(data) |
34 | | - |
35 | | - @functools.wraps(_original_pin_memory_loop) |
36 | | - def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device): |
37 | | - """Patched _pin_memory_loop to accept a device argument.""" |
38 | | - return _original_pin_memory_loop(in_queue, out_queue, device_id, done_event) |
39 | | - |
40 | | - torch_data_utils.pin_memory.pin_memory = _patched_pin_memory |
41 | | - torch_data_utils.pin_memory._pin_memory_loop = _pin_memory_loop |
42 | | - |
43 | | - |
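The removed block above shims `pin_memory` by inspecting its signature and only accepting a `device` argument when the installed torch build lacks it, so callers such as torchdata that pass `device=` keep working. A minimal, hedged sketch of that general technique, with a hypothetical helper name that is not part of the codebase:

```python
import functools
import inspect


def make_kwarg_tolerant(fn, optional_kwargs=("device",)):
    """Wrap fn so callers may pass kwargs that this build of fn may not accept."""
    accepted = inspect.signature(fn).parameters

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Silently drop optional kwargs the underlying function does not know about.
        # (For brevity this ignores functions that declare **kwargs.)
        for name in optional_kwargs:
            if name in kwargs and name not in accepted:
                kwargs.pop(name)
        return fn(*args, **kwargs)

    return wrapper
```

The idea is the same as the patch being removed: the signature check happens once, and call sites never need to branch on the torch version.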
44 | | -# Monkey patch DeviceMesh to fix corner case in mesh slicing |
45 | | -# Fixes issue where _dim_group_names is accessed without checking if rank is in mesh |
46 | | -# Based on https://github.com/pytorch/pytorch/pull/169454/files |
47 | | -try: |
48 | | - import torch as _torch |
49 | | - |
50 | | - # Only apply the patch for the specific PyTorch version with the regression |
51 | | -    # TODO: Remove this once we bump to a newer PyTorch version with the fix
52 | | - if "2.10.0" in _torch.__version__ and "nv25.11" in _torch.__version__: |
53 | | - from torch.distributed._mesh_layout import _MeshLayout |
54 | | - from torch.distributed.device_mesh import _MeshEnv |
55 | | - |
56 | | - _original_get_slice_mesh_layout = _MeshEnv._get_slice_mesh_layout |
57 | | - |
58 | | - def _patched_get_slice_mesh_layout(self, device_mesh, mesh_dim_names): |
59 | | - """ |
60 | | - Patched _get_slice_mesh_layout based on PyTorch PR #169454. |
61 | | - This fixes: |
62 | | - 1. _dim_group_names access (commit f6c8092) |
63 | | - 2. Regression in mesh slicing with size-1 dims (PR #169454 / Issue #169381) |
64 | | - """ |
65 | | - # 1. First, build the layout manually to bypass the legacy 'stride < pre_stride' check |
66 | | - slice_from_root = device_mesh == self.get_root_mesh(device_mesh) |
67 | | - flatten_name_to_root_layout = ( |
68 | | - {key: mesh._layout for key, mesh in self.root_to_flatten_mapping.setdefault(device_mesh, {}).items()} |
69 | | - if slice_from_root |
70 | | - else {} |
71 | | - ) |
72 | | - |
73 | | - mesh_dim_names_list = getattr(device_mesh, "mesh_dim_names", []) |
74 | | - valid_mesh_dim_names = [*mesh_dim_names_list, *flatten_name_to_root_layout] |
75 | | - if not all(name in valid_mesh_dim_names for name in mesh_dim_names): |
76 | | - raise KeyError(f"Invalid mesh_dim_names {mesh_dim_names}. Valid: {valid_mesh_dim_names}") |
77 | | - |
78 | | - layout_sliced = [] |
79 | | - for name in mesh_dim_names: |
80 | | - if name in mesh_dim_names_list: |
81 | | - layout_sliced.append(device_mesh._layout[mesh_dim_names_list.index(name)]) |
82 | | - elif name in flatten_name_to_root_layout: |
83 | | - layout_sliced.append(flatten_name_to_root_layout[name]) |
84 | | - |
85 | | - sliced_sizes = tuple(layout.sizes for layout in layout_sliced) |
86 | | - sliced_strides = tuple(layout.strides for layout in layout_sliced) |
87 | | - |
88 | | - # Bypass the 'stride < pre_stride' check that exists in the original |
89 | | - # and create the MeshLayout directly. |
90 | | - slice_mesh_layout = _MeshLayout(sliced_sizes, sliced_strides) |
91 | | - |
92 | | - if not slice_mesh_layout.check_non_overlap(): |
93 | | - raise RuntimeError(f"Slicing overlapping dim_names {mesh_dim_names} is not allowed.") |
94 | | - |
95 | | - # 2. Replicate the _dim_group_names fix (commit f6c8092) |
96 | | - # We need to return an object that HAS _dim_group_names if the rank is in the mesh |
97 | | - if hasattr(device_mesh, "_dim_group_names") and len(device_mesh._dim_group_names) > 0: |
98 | | - slice_dim_group_name = [] |
99 | | - submesh_dim_names = mesh_dim_names if isinstance(mesh_dim_names, tuple) else (mesh_dim_names,) |
100 | | - for name in submesh_dim_names: |
101 | | - if name in mesh_dim_names_list: |
102 | | - slice_dim_group_name.append(device_mesh._dim_group_names[mesh_dim_names_list.index(name)]) |
103 | | - elif hasattr(device_mesh, "_flatten_mapping") and name in device_mesh._flatten_mapping: |
104 | | - flatten_mesh = device_mesh._flatten_mapping[name] |
105 | | - slice_dim_group_name.append( |
106 | | - flatten_mesh._dim_group_names[flatten_mesh.mesh_dim_names.index(name)] |
107 | | - ) |
108 | | - |
109 | | - # Attach the group names to the layout object so the caller can use them |
110 | | - object.__setattr__(slice_mesh_layout, "_dim_group_names", slice_dim_group_name) |
111 | | - |
112 | | - return slice_mesh_layout |
113 | | - |
114 | | - # Apply the patch |
115 | | - _MeshEnv._get_slice_mesh_layout = _patched_get_slice_mesh_layout |
116 | | - logging.getLogger(__name__).debug(f"Applied DeviceMesh fix for PyTorch {_torch.__version__}") |
117 | | - |
118 | | -except (ImportError, AttributeError) as e: |
119 | | - logging.getLogger(__name__).debug(f"Could not apply DeviceMesh patch: {e}") |
120 | | - pass |
121 | 14 |
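Both removed patches follow the same version-gated monkey-patch pattern: keep a handle to the original, swap in the replacement only on affected builds, and log what happened. A generic, hedged sketch of that pattern (helper name and signature are illustrative, not part of the codebase):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def apply_versioned_patch(owner, attr_name, patched_fn, version_markers):
    """Replace owner.<attr_name> with patched_fn only on affected torch builds."""
    if not all(marker in torch.__version__ for marker in version_markers):
        return None
    original = getattr(owner, attr_name)
    setattr(owner, attr_name, patched_fn)
    logger.debug("Patched %s.%s for torch %s", getattr(owner, "__name__", owner), attr_name, torch.__version__)
    return original  # keep a handle so the patch can be undone, e.g. in tests
```

For example, `apply_versioned_patch(_MeshEnv, "_get_slice_mesh_layout", _patched_get_slice_mesh_layout, ("2.10.0", "nv25.11"))` would mirror the version gate in the removed block.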
| 15 | +import importlib |
122 | 16 |
123 | 17 | from .package_info import __package_name__, __version__ |
124 | 18 |
125 | | -__all__ = [ |
126 | | - "recipes", |
127 | | - "shared", |
128 | | - "components", |
129 | | - "__version__", |
130 | | - "__package_name__", |
131 | | -] |
| 19 | +# Keep the base package import lightweight. |
| 20 | +# Heavy dependencies (e.g., torch/transformers) are intentionally imported lazily |
| 21 | +# via __getattr__ so importing tokenizers doesn't pull in the full training stack. |
| 22 | + |
| 23 | +_SUBMODULES = {"recipes", "shared", "components"} |
132 | 24 |
133 | | -# Promote NeMoAutoModelForCausalLM, AutoModelForImageTextToText into the top level |
134 | | -# to enable: `from nemo_automodel import NeMoAutoModelForCausalLM` |
135 | | -try: |
136 | | - # adjust this import path if your class lives somewhere else |
137 | | - from nemo_automodel._transformers.auto_model import ( |
138 | | - NeMoAutoModelForCausalLM, |
139 | | - NeMoAutoModelForImageTextToText, |
140 | | - NeMoAutoModelForSequenceClassification, |
141 | | - NeMoAutoModelForTextToWaveform, |
142 | | - ) # noqa: I001 |
143 | | - from nemo_automodel._transformers.auto_tokenizer import NeMoAutoTokenizer |
| 25 | +_LAZY_ATTRS: dict[str, tuple[str, str]] = { |
| 26 | + "NeMoAutoModelForCausalLM": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForCausalLM"), |
| 27 | + "NeMoAutoModelForImageTextToText": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForImageTextToText"), |
| 28 | + "NeMoAutoModelForSequenceClassification": ( |
| 29 | + "nemo_automodel._transformers.auto_model", |
| 30 | + "NeMoAutoModelForSequenceClassification", |
| 31 | + ), |
| 32 | + "NeMoAutoModelForTextToWaveform": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForTextToWaveform"), |
| 33 | + "NeMoAutoTokenizer": ("nemo_automodel._transformers.auto_tokenizer", "NeMoAutoTokenizer"), |
| 34 | +} |
144 | 35 |
145 | | - globals()["NeMoAutoModelForCausalLM"] = NeMoAutoModelForCausalLM |
146 | | - globals()["NeMoAutoModelForImageTextToText"] = NeMoAutoModelForImageTextToText |
147 | | - globals()["NeMoAutoModelForSequenceClassification"] = NeMoAutoModelForSequenceClassification |
148 | | - globals()["NeMoAutoModelForTextToWaveform"] = NeMoAutoModelForTextToWaveform |
149 | | - globals()["NeMoAutoTokenizer"] = NeMoAutoTokenizer |
150 | | - __all__.append("NeMoAutoModelForCausalLM") |
151 | | - __all__.append("NeMoAutoModelForImageTextToText") |
152 | | - __all__.append("NeMoAutoModelForSequenceClassification") |
153 | | - __all__.append("NeMoAutoModelForTextToWaveform") |
154 | | - __all__.append("NeMoAutoTokenizer") |
155 | | -except: |
156 | | - # optional dependency might be missing, |
157 | | - # leave the name off the module namespace so other imports still work |
158 | | - pass |
| 36 | +__all__ = sorted([*_SUBMODULES, "__version__", "__package_name__", *_LAZY_ATTRS.keys()]) |
159 | 37 |
160 | 38 |
161 | 39 | def __getattr__(name: str): |
162 | 40 | """ |
163 | | - Lazily import and cache submodules listed in __all__ when accessed. |
| 41 | + Lazily import and cache selected submodules / exported symbols when accessed. |
164 | 42 |
165 | 43 | Raises: |
166 | 44 | AttributeError if the name isn’t in __all__. |
167 | 45 | """ |
168 | | - if name in __all__: |
169 | | - # import submodule on first access |
| 46 | + if name in _SUBMODULES: |
170 | 47 | module = importlib.import_module(f"{__name__}.{name}") |
171 | | - # cache it in globals() so future lookups do not re-import |
172 | 48 | globals()[name] = module |
173 | 49 | return module |
| 50 | + if name in _LAZY_ATTRS: |
| 51 | + module_name, attr_name = _LAZY_ATTRS[name] |
| 52 | + module = importlib.import_module(module_name) |
| 53 | + attr = getattr(module, attr_name) |
| 54 | + globals()[name] = attr |
| 55 | + return attr |
174 | 56 | raise AttributeError(f"module {__name__!r} has no attribute {name!r}") |
175 | 57 |
176 | 58 |
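For context, an illustrative (not authoritative) look at how the new lazy exports behave from the caller's side, assuming the package and its optional transformers-backed dependencies are installed:

```python
import nemo_automodel

# Cheap: only package_info was imported eagerly.
print(nemo_automodel.__package_name__, nemo_automodel.__version__)

# First access goes through the module-level __getattr__ (PEP 562), which
# imports nemo_automodel._transformers.auto_model and caches the symbol in
# globals(), so subsequent lookups bypass __getattr__ entirely.
model_cls = nemo_automodel.NeMoAutoModelForCausalLM

# Submodules listed in _SUBMODULES resolve the same way on first access.
recipes = nemo_automodel.recipes
```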