Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 66 additions & 9 deletions auto_round/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,19 @@ def _load_state_dict_into_module(state_dict: dict, module: torch.nn.Module) -> N
del state_dict


def _clear_module_weights(module: torch.nn.Module, cache_numel: bool = False) -> None:
def _clear_module_weights(
module: torch.nn.Module, cache_numel: bool = False, restorable_params: set | None = None
) -> None:
"""Replace a single module's weight/bias with empty tensors.

Args:
module: The leaf module to clear.
cache_numel: If *True*, store ``_cached_weight_numel`` and
``_cached_weight_shape`` before clearing.
restorable_params: If provided, only clear parameters whose names are
in this set. Parameters not in the set are kept intact because
they cannot be restored from the checkpoint (e.g. dynamically
registered calibration parameters).
"""
if module is None:
return
Expand All @@ -108,6 +114,8 @@ def _clear_module_weights(module: torch.nn.Module, cache_numel: bool = False) ->
for name, param in list(module.named_parameters(recurse=False)):
if param is None or param.numel() == 0:
continue
if restorable_params is not None and name not in restorable_params:
continue
if cache_numel and name == "weight":
module._cached_weight_numel = param.numel()
module._cached_weight_shape = tuple(param.shape)
Expand All @@ -119,6 +127,8 @@ def _clear_module_weights(module: torch.nn.Module, cache_numel: bool = False) ->
for name, buf in list(module.named_buffers(recurse=False)):
if buf is None or buf.numel() == 0:
continue
if restorable_params is not None and name not in restorable_params:
continue
module.register_buffer(name, torch.empty(0, dtype=buf.dtype, device="cpu"))


Expand Down Expand Up @@ -274,6 +284,9 @@ def __init__(
self._tempdir: Optional[str] = None
self._saved: dict[str, dict] = {} # name -> {"save_path": str}

# Cached weight map for clean mode (avoids repeated disk I/O)
self._weight_map: dict[str, str] | None = None

# Hook state (for add_offload_hooks/remove_offload_hooks transparent offloading)
self._hook_handles: list = []
self._model_ref: Optional[torch.nn.Module] = None
Expand Down Expand Up @@ -451,7 +464,7 @@ def _offload(
elif skip_if_saved and name in self._saved:
return
self._save_to_disk(name, module)
self._clear(module)
self._clear(module, block_name=name)

def reload(self, model: torch.nn.Module, names: Union[str, list[str], None] = None) -> None:
"""Reload previously offloaded module(s).
Expand Down Expand Up @@ -501,6 +514,11 @@ def add_offload_hooks(self, model: torch.nn.Module, names: list[str]) -> None:
"""Clear all named modules and register pre-forward hooks for
transparent reload-on-demand.

In ``"clean"`` mode, only parameters that exist in the original
checkpoint are cleared. Dynamically registered parameters (e.g.
calibration scales from attention/KV-cache quantization) are
preserved so they remain valid across reload cycles.

Args:
model: The root model.
names: List of module names to manage.
Expand Down Expand Up @@ -548,7 +566,7 @@ def _pre_forward_hook(self, module: torch.nn.Module, args, *, name: str) -> None
if self._last_loaded is not None and self._last_loaded != name:
prev = get_module(self._model_ref, self._last_loaded)
if prev is not None:
self._clear(prev)
self._clear(prev, block_name=self._last_loaded)
self.reload(self._model_ref, name)
self._last_loaded = name

Expand Down Expand Up @@ -588,7 +606,7 @@ def ensure_loaded(self, model: torch.nn.Module, layer_name: str) -> None:
if self._current_loaded is not None:
module = get_module(model, self._current_loaded)
if module is not None:
self._clear(module)
self._clear(module, block_name=self._current_loaded)
# Load new
self.reload(model, target)
self._current_loaded = target
Expand All @@ -598,7 +616,7 @@ def flush_loaded(self, model: torch.nn.Module) -> None:
if self._current_loaded is not None:
module = get_module(model, self._current_loaded)
if module is not None:
self._clear(module)
self._clear(module, block_name=self._current_loaded)
self._current_loaded = None

# ------------------------------------------------------------------
Expand Down Expand Up @@ -745,10 +763,49 @@ def _cleanup_tempdir(self) -> None:
# Internal: clearing
# ------------------------------------------------------------------

def _clear(self, module: torch.nn.Module) -> None:
"""Clear all weight/bias tensors in *module* and its sub-modules."""
for submodule in module.modules():
_clear_module_weights(submodule, cache_numel=self.cache_numel)
def _get_restorable_params(self, block_name: str) -> set[str] | None:
"""Return the set of parameter names under *block_name* that exist in
the checkpoint, or *None* if the weight map is unavailable.

In ``"clean"`` mode the offload manager reloads weights from the
original checkpoint. Parameters dynamically registered at runtime
(e.g. calibration scales) do **not** appear in the checkpoint, so
clearing them would leave them as empty tensors with no way to
restore. This helper builds the set of parameters that *can* be
restored, allowing ``_clear`` to skip the rest.
"""
if self.mode != "clean" or not self.model_dir:
return None
try:
if self._weight_map is None:
model_dir = _resolve_model_dir(self.model_dir)
self._weight_map = _build_weight_map(model_dir)
weight_map = self._weight_map
except Exception:
return None
prefix = block_name + "."
return {k[len(prefix) :] for k in weight_map if k.startswith(prefix)}

def _clear(self, module: torch.nn.Module, block_name: str | None = None) -> None:
    """Clear weight/bias tensors in *module* and all of its sub-modules.

    In ``"clean"`` mode with a known *block_name*, only parameters that
    exist in the original checkpoint are cleared; dynamically registered
    parameters (e.g. calibration scales) are preserved because they could
    not be restored from the checkpoint on reload.

    Args:
        module: Root module whose parameters/buffers should be cleared.
        block_name: Fully qualified name of *module* within the model,
            used to look up which parameters are restorable. When ``None``
            (or empty), everything is cleared unconditionally.
    """
    restorable = self._get_restorable_params(block_name) if block_name else None
    if restorable is None:
        # No checkpoint information available: clear everything, as before.
        for submodule in module.modules():
            _clear_module_weights(submodule, cache_numel=self.cache_numel)
        return
    # Group restorable names by owning submodule in a single pass instead
    # of re-scanning the whole set for every submodule: O(P + M) rather
    # than O(P * M). "a.b.weight" -> owner "a.b", basename "weight";
    # "weight" -> owner "" (the block root), matching named_modules().
    by_owner: dict[str, set[str]] = {}
    for qualified in restorable:
        owner, _, basename = qualified.rpartition(".")
        by_owner.setdefault(owner, set()).add(basename)
    for name, submodule in module.named_modules():
        # A submodule with no checkpoint-backed params gets an empty set,
        # so _clear_module_weights skips all of its tensors.
        _clear_module_weights(
            submodule,
            cache_numel=self.cache_numel,
            restorable_params=by_owner.get(name, set()),
        )

@staticmethod
def _needs_loading(module: torch.nn.Module) -> bool:
Expand Down
Loading