Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3406,8 +3406,8 @@ def _get_quantized_layer_names_outside_blocks(self) -> list:
continue
layer = get_module(self.model, key)
if layer is None:
logger.error(f"could not find layer {key} in the model, exit...")
exit(-1)
logger.warning_once(f"could not find layer {key} in the model, skipping")
continue
if type(layer) in self.supported_types and check_to_quantized(self.layer_config[key]):
layer_names.append(key)

Expand Down
49 changes: 32 additions & 17 deletions auto_round/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,38 +1165,53 @@ def _to_model_dtype(model, model_dtype):
return model


def get_attr(module, key):
    """Resolve a dotted attribute path (e.g. ``"lm_head.weight"``) on *module*.

    Unlike plain ``getattr``, this walks each dot-separated component in
    turn, so it can reach nested attributes, including parameters such as
    ``layer.weight``.

    Args:
        module: root object, typically a ``torch.nn.Module``.
        key (str): dot-separated attribute path.

    Returns:
        The resolved attribute, or ``None`` when any component along the
        path is missing (legacy non-fail-fast behavior relied on by tests).
    """
    current = module
    for part in key.split("."):
        # Bail out as soon as the chain breaks; callers treat None as "absent".
        if current is None:
            return None
        current = getattr(current, part, None)
    return current


def set_attr(model, key, new_attr):
    """Assign *new_attr* at a dotted attribute path (e.g. ``"lm_head.weight"``).

    Walks every component except the last to locate the parent object, then
    ``setattr``s the final component on it.

    Args:
        model: root object, typically a ``torch.nn.Module``.
        key (str): dot-separated attribute path.
        new_attr: value to assign at the final path component.

    Note:
        If any intermediate parent does not exist, the call is a silent
        no-op (legacy behavior: callers rely on it not raising).
    """
    parts = key.split(".")
    parent = model
    for name in parts[:-1]:
        if not hasattr(parent, name):
            # Missing parent: deliberately do nothing rather than raise.
            return
        parent = getattr(parent, name)
    setattr(parent, parts[-1], new_attr)


def get_module(module, key):
    """Look up a submodule by its dotted name via the native PyTorch API.

    Returns ``None`` instead of raising when the path cannot be resolved,
    preserving the legacy non-fail-fast contract.
    """
    try:
        found = module.get_submodule(key)
    except (AttributeError, KeyError):
        # get_submodule raises AttributeError on bad paths; KeyError is kept
        # for safety with dict-backed module containers.
        found = None
    return found

# For getting and setting attribution, such as 'lm_head.weight'
get_attr = get_module
set_attr = set_module

def set_module(model, key, new_module):
    """Replace the submodule at dotted name *key* via the native PyTorch API.

    Unresolvable paths are ignored (no-op) to preserve legacy behavior.
    """
    try:
        model.set_submodule(key, new_module)
    except (AttributeError, KeyError):
        # Swallow lookup failures: legacy callers expect a silent no-op.
        pass


def get_layer_features(layer):
Expand Down
Loading