|
14 | 14 | from loguru import logger
|
15 | 15 | from pydantic import ConfigDict, PrivateAttr, model_validator
|
16 | 16 | from torch.nn import Module
|
| 17 | +from operator import attrgetter |
17 | 18 | from tqdm import tqdm
|
18 | 19 |
|
19 | 20 | from llmcompressor.core import Event, EventType, State
|
|
29 | 30 | from llmcompressor.pipelines.cache import IntermediatesCache
|
30 | 31 | from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
|
31 | 32 | from llmcompressor.utils.helpers import calibration_forward_context
|
32 |
| -from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers |
| 33 | +from llmcompressor.utils.pytorch.module import get_layers |
33 | 34 |
|
34 | 35 | __all__ = ["AWQModifier"]
|
35 | 36 |
|
@@ -323,7 +324,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
|
323 | 324 | smooth_layer = smooth_layers[smooth_name]
|
324 | 325 |
|
325 | 326 | smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
|
326 |
| - smooth_parent = get_layer_by_name(smooth_parent_name, model) |
| 327 | + smooth_parent = attrgetter(smooth_parent_name)(model) |
327 | 328 |
|
328 | 329 | balance_layers, balance_names = [], []
|
329 | 330 | for balance_regex in mapping.balance_layers:
|
@@ -765,7 +766,7 @@ def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Mod
|
765 | 766 | while True:
|
766 | 767 | if parent_name == "":
|
767 | 768 | return "", module
|
768 |
| - parent = get_layer_by_name(parent_name, module) |
| 769 | + parent = attrgetter(parent_name)(module) |
769 | 770 | if not isinstance(parent, torch.nn.ModuleList):
|
770 | 771 | return parent_name, parent
|
771 | 772 | parent_name = ".".join(parent_name.split(".")[:-1])
|
0 commit comments