1919"""
2020
2121import re as _re
22+ from collections .abc import Mapping , MutableMapping
2223from typing import TYPE_CHECKING , Any , Iterator , Optional
2324
2425if TYPE_CHECKING :
@@ -60,7 +61,7 @@ def get_native_keys_lazy(adapter: Any, hf_state_dict: dict[str, Any]) -> list[st
6061 return out
6162
6263
63- class LazyHFStateDict :
64+ class LazyHFStateDict ( Mapping ) :
6465 """Dict-like wrapper that converts native -> HF on key access (JIT). Reduces peak GPU memory.
6566
6667 Args:
@@ -110,14 +111,23 @@ def __len__(self) -> int:
110111 self ._keys = self ._compute_keys ()
111112 return len (self ._keys )
112113
def values(self) -> Iterator[Any]:
    """Lazily yield the value for every key, in ``keys()`` order.

    Each value is fetched through ``self[key]`` only when the consumer
    advances the iterator, so nothing is looked up ahead of time.
    """
    yield from (self[name] for name in self.keys())
117+
def items(self) -> Iterator[tuple[str, Any]]:
    """Lazily yield ``(key, value)`` pairs, in ``keys()`` order.

    Each value is fetched through ``self[key]`` only when the consumer
    advances the iterator.
    """
    yield from ((name, self[name]) for name in self.keys())
116121
117122
118- class LazyNativeStateDict :
123+ class LazyNativeStateDict ( MutableMapping ) :
119124 """Dict-like wrapper that converts HF -> native on key access (JIT). Merges incrementally to limit peak memory.
120125 When native_backing is set (round-trip from LazyHFStateDict), returns tensors from it directly (zero copy).
126+
127+ Supports in-place mutation (``__setitem__``, ``__delitem__``, ``pop``,
128+ ``update``) via an overlay so that downstream code (e.g. key-renaming
129+ helpers, ``_extra_state`` injection) can treat this object like a plain
130+ ``dict`` without materialising the full lazy mapping.
121131 """
122132
123133 def __init__ (
@@ -131,22 +141,41 @@ def __init__(
131141 self ._adapter = adapter
132142 self ._device_mesh = device_mesh
133143 self ._native_backing = native_backing
134- self ._keys = None
144+ self ._base_keys : Optional [list [str ]] = None
145+ # Overlay for mutation – avoids materialising the full lazy mapping.
146+ self ._overrides : dict [str , Any ] = {}
147+ self ._deleted : set [str ] = set ()
135148 if (
136149 getattr (adapter , "_validate_expert_availability" , None ) is not None
137150 and getattr (adapter , "moe_config" , None ) is not None
138151 ):
139152 adapter ._validate_expert_availability (hf_state_dict , adapter .moe_config .n_routed_experts , device_mesh )
140153
154+ # Internal helpers
155+ def _get_base_keys (self ) -> list [str ]:
156+ if self ._base_keys is None :
157+ self ._base_keys = get_native_keys_lazy (self ._adapter , self ._hf_state_dict )
158+ return self ._base_keys
159+
160+ # Read API
141161 def __iter__ (self ) -> Iterator [str ]:
142162 return self .keys ()
143163
144164 def keys (self ) -> Iterator [str ]:
145- if self ._keys is None :
146- self ._keys = get_native_keys_lazy (self ._adapter , self ._hf_state_dict )
147- return iter (self ._keys )
165+ seen : set [str ] = set ()
166+ for k in self ._get_base_keys ():
167+ if k not in self ._deleted or k in self ._overrides :
168+ seen .add (k )
169+ yield k
170+ for k in self ._overrides :
171+ if k not in seen :
172+ yield k
148173
149174 def __getitem__ (self , native_key : str ) -> Any :
175+ if native_key in self ._overrides :
176+ return self ._overrides [native_key ]
177+ if native_key in self ._deleted :
178+ raise KeyError (native_key )
150179 if self ._native_backing is not None and native_key in self ._native_backing :
151180 return self ._native_backing [native_key ]
152181 get_merged = getattr (self ._adapter , "get_merged_tensor_for_native_key" , None )
@@ -156,15 +185,63 @@ def __getitem__(self, native_key: str) -> Any:
156185 return merged
157186 return self ._hf_state_dict [native_key ]
158187
188+ def get (self , key : str , default : Any = None ) -> Any :
189+ try :
190+ return self [key ]
191+ except KeyError :
192+ return default
193+
159194 def __contains__ (self , key : object ) -> bool :
160- if self ._keys is None :
161- self ._keys = get_native_keys_lazy (self ._adapter , self ._hf_state_dict )
162- return key in self ._keys
195+ if key in self ._overrides :
196+ return True
197+ if key in self ._deleted :
198+ return False
199+ return key in self ._get_base_keys ()
163200
164201 def __len__ (self ) -> int :
165- if self ._keys is None :
166- self ._keys = get_native_keys_lazy (self ._adapter , self ._hf_state_dict )
167- return len (self ._keys )
202+ base = set (self ._get_base_keys ())
203+ return len ((base - self ._deleted ) | set (self ._overrides ))
204+
205+ # Mutation API
206+ def __setitem__ (self , key : str , value : Any ) -> None :
207+ self ._overrides [key ] = value
208+ self ._deleted .discard (key )
209+
210+ def __delitem__ (self , key : str ) -> None :
211+ if key not in self :
212+ raise KeyError (key )
213+ self ._overrides .pop (key , None )
214+ self ._deleted .add (key )
215+
216+ def pop (self , key : str , * default : Any ) -> Any :
217+ try :
218+ value = self [key ]
219+ except KeyError :
220+ if default :
221+ return default [0 ]
222+ raise
223+ self ._overrides .pop (key , None )
224+ self ._deleted .add (key )
225+ return value
226+
227+ def update (self , other : Any = None , ** kwargs : Any ) -> None :
228+ if other is not None :
229+ if hasattr (other , "items" ):
230+ for k , v in other .items ():
231+ self [k ] = v
232+ elif hasattr (other , "keys" ):
233+ for k in other .keys ():
234+ self [k ] = other [k ]
235+ else :
236+ for k , v in other :
237+ self [k ] = v
238+ for k , v in kwargs .items ():
239+ self [k ] = v
240+
241+ # Iteration helpers
242+ def values (self ) -> Iterator [Any ]:
243+ for k in self .keys ():
244+ yield self [k ]
168245
169246 def items (self ) -> Iterator [tuple [str , Any ]]:
170247 for k in self .keys ():
0 commit comments