|
14 | 14 | from loguru import logger
|
15 | 15 | from pydantic import ConfigDict, PrivateAttr, model_validator
|
16 | 16 | from torch.nn import Module
|
17 |
| -from operator import attrgetter |
18 | 17 | from tqdm import tqdm
|
19 | 18 |
|
20 | 19 | from llmcompressor.core import Event, EventType, State
|
@@ -324,14 +323,16 @@ def _set_resolved_mappings(self, model: Module) -> None:
|
324 | 323 | smooth_layer = smooth_layers[smooth_name]
|
325 | 324 |
|
326 | 325 | smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
|
327 |
| - smooth_parent = attrgetter(smooth_parent_name)(model) if smooth_parent_name else model |
| 326 | + smooth_parent = ( |
| 327 | + model if smooth_parent_name == "" else model.get_submodule(smooth_parent_name) |
| 328 | + ) |
328 | 329 |
|
329 | 330 | balance_layers, balance_names = [], []
|
330 | 331 | for balance_regex in mapping.balance_layers:
|
331 | 332 | # find the submodules that match the activation layer
|
332 | 333 | for balance_suffix, balance_layer in match_named_modules(
|
333 |
| - balance_regex, |
334 | 334 | smooth_parent,
|
| 335 | + balance_regex, |
335 | 336 | exclude_internal_modules=True,
|
336 | 337 | ).items():
|
337 | 338 | balance_name = f"{smooth_parent_name}.{balance_suffix}"
|
@@ -766,7 +767,7 @@ def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Mod
|
766 | 767 | while True:
|
767 | 768 | if parent_name == "":
|
768 | 769 | return "", module
|
769 |
| - parent = attrgetter(parent_name)(module) |
| 770 | + parent = module.get_submodule(parent_name) |
770 | 771 | if not isinstance(parent, torch.nn.ModuleList):
|
771 | 772 | return parent_name, parent
|
772 | 773 | parent_name = ".".join(parent_name.split(".")[:-1])
|
0 commit comments