Skip to content

Commit 853b03d

Browse files
committed
dict api
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 460fed5 commit 853b03d

File tree

1 file changed

+89
-12
lines changed

1 file changed

+89
-12
lines changed

nemo_automodel/components/models/common/state_dict_lazy.py

Lines changed: 89 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"""
2020

2121
import re as _re
22+
from collections.abc import Mapping, MutableMapping
2223
from typing import TYPE_CHECKING, Any, Iterator, Optional
2324

2425
if TYPE_CHECKING:
@@ -60,7 +61,7 @@ def get_native_keys_lazy(adapter: Any, hf_state_dict: dict[str, Any]) -> list[st
6061
return out
6162

6263

63-
class LazyHFStateDict:
64+
class LazyHFStateDict(Mapping):
6465
"""Dict-like wrapper that converts native -> HF on key access (JIT). Reduces peak GPU memory.
6566
6667
Args:
@@ -110,14 +111,23 @@ def __len__(self) -> int:
110111
self._keys = self._compute_keys()
111112
return len(self._keys)
112113

114+
def values(self) -> Iterator[Any]:
115+
for k in self.keys():
116+
yield self[k]
117+
113118
def items(self) -> Iterator[tuple[str, Any]]:
114119
for k in self.keys():
115120
yield k, self[k]
116121

117122

118-
class LazyNativeStateDict:
123+
class LazyNativeStateDict(MutableMapping):
119124
"""Dict-like wrapper that converts HF -> native on key access (JIT). Merges incrementally to limit peak memory.
120125
When native_backing is set (round-trip from LazyHFStateDict), returns tensors from it directly (zero copy).
126+
127+
Supports in-place mutation (``__setitem__``, ``__delitem__``, ``pop``,
128+
``update``) via an overlay so that downstream code (e.g. key-renaming
129+
helpers, ``_extra_state`` injection) can treat this object like a plain
130+
``dict`` without materialising the full lazy mapping.
121131
"""
122132

123133
def __init__(
@@ -131,22 +141,41 @@ def __init__(
131141
self._adapter = adapter
132142
self._device_mesh = device_mesh
133143
self._native_backing = native_backing
134-
self._keys = None
144+
self._base_keys: Optional[list[str]] = None
145+
# Overlay for mutation – avoids materialising the full lazy mapping.
146+
self._overrides: dict[str, Any] = {}
147+
self._deleted: set[str] = set()
135148
if (
136149
getattr(adapter, "_validate_expert_availability", None) is not None
137150
and getattr(adapter, "moe_config", None) is not None
138151
):
139152
adapter._validate_expert_availability(hf_state_dict, adapter.moe_config.n_routed_experts, device_mesh)
140153

154+
# Internal helpers
155+
def _get_base_keys(self) -> list[str]:
156+
if self._base_keys is None:
157+
self._base_keys = get_native_keys_lazy(self._adapter, self._hf_state_dict)
158+
return self._base_keys
159+
160+
# Read API
141161
def __iter__(self) -> Iterator[str]:
142162
return self.keys()
143163

144164
def keys(self) -> Iterator[str]:
145-
if self._keys is None:
146-
self._keys = get_native_keys_lazy(self._adapter, self._hf_state_dict)
147-
return iter(self._keys)
165+
seen: set[str] = set()
166+
for k in self._get_base_keys():
167+
if k not in self._deleted or k in self._overrides:
168+
seen.add(k)
169+
yield k
170+
for k in self._overrides:
171+
if k not in seen:
172+
yield k
148173

149174
def __getitem__(self, native_key: str) -> Any:
175+
if native_key in self._overrides:
176+
return self._overrides[native_key]
177+
if native_key in self._deleted:
178+
raise KeyError(native_key)
150179
if self._native_backing is not None and native_key in self._native_backing:
151180
return self._native_backing[native_key]
152181
get_merged = getattr(self._adapter, "get_merged_tensor_for_native_key", None)
@@ -156,15 +185,63 @@ def __getitem__(self, native_key: str) -> Any:
156185
return merged
157186
return self._hf_state_dict[native_key]
158187

188+
def get(self, key: str, default: Any = None) -> Any:
189+
try:
190+
return self[key]
191+
except KeyError:
192+
return default
193+
159194
def __contains__(self, key: object) -> bool:
160-
if self._keys is None:
161-
self._keys = get_native_keys_lazy(self._adapter, self._hf_state_dict)
162-
return key in self._keys
195+
if key in self._overrides:
196+
return True
197+
if key in self._deleted:
198+
return False
199+
return key in self._get_base_keys()
163200

164201
def __len__(self) -> int:
165-
if self._keys is None:
166-
self._keys = get_native_keys_lazy(self._adapter, self._hf_state_dict)
167-
return len(self._keys)
202+
base = set(self._get_base_keys())
203+
return len((base - self._deleted) | set(self._overrides))
204+
205+
# Mutation API
206+
def __setitem__(self, key: str, value: Any) -> None:
207+
self._overrides[key] = value
208+
self._deleted.discard(key)
209+
210+
def __delitem__(self, key: str) -> None:
211+
if key not in self:
212+
raise KeyError(key)
213+
self._overrides.pop(key, None)
214+
self._deleted.add(key)
215+
216+
def pop(self, key: str, *default: Any) -> Any:
217+
try:
218+
value = self[key]
219+
except KeyError:
220+
if default:
221+
return default[0]
222+
raise
223+
self._overrides.pop(key, None)
224+
self._deleted.add(key)
225+
return value
226+
227+
def update(self, other: Any = None, **kwargs: Any) -> None:
228+
if other is not None:
229+
if hasattr(other, "items"):
230+
for k, v in other.items():
231+
self[k] = v
232+
elif hasattr(other, "keys"):
233+
for k in other.keys():
234+
self[k] = other[k]
235+
else:
236+
for k, v in other:
237+
self[k] = v
238+
for k, v in kwargs.items():
239+
self[k] = v
240+
241+
# Iteration helpers
242+
def values(self) -> Iterator[Any]:
243+
for k in self.keys():
244+
yield self[k]
168245

169246
def items(self) -> Iterator[tuple[str, Any]]:
170247
for k in self.keys():

0 commit comments

Comments
 (0)