Skip to content

Commit 63cb6e6

Browse files
committed
get_layer_by_name refactor
1 parent 3b4a0da commit 63cb6e6

File tree

1 file changed

+4
-3
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+4
-3
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from loguru import logger
1515
from pydantic import ConfigDict, PrivateAttr, model_validator
1616
from torch.nn import Module
17+
from operator import attrgetter
1718
from tqdm import tqdm
1819

1920
from llmcompressor.core import Event, EventType, State
@@ -29,7 +30,7 @@
2930
from llmcompressor.pipelines.cache import IntermediatesCache
3031
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
3132
from llmcompressor.utils.helpers import calibration_forward_context
32-
from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers
33+
from llmcompressor.utils.pytorch.module import get_layers
3334

3435
__all__ = ["AWQModifier"]
3536

@@ -323,7 +324,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
323324
smooth_layer = smooth_layers[smooth_name]
324325

325326
smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
326-
smooth_parent = get_layer_by_name(smooth_parent_name, model)
327+
smooth_parent = attrgetter(smooth_parent_name)(model)
327328

328329
balance_layers, balance_names = [], []
329330
for balance_regex in mapping.balance_layers:
@@ -765,7 +766,7 @@ def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Mod
765766
while True:
766767
if parent_name == "":
767768
return "", module
768-
parent = get_layer_by_name(parent_name, module)
769+
parent = attrgetter(parent_name)(module)
769770
if not isinstance(parent, torch.nn.ModuleList):
770771
return parent_name, parent
771772
parent_name = ".".join(parent_name.split(".")[:-1])

0 commit comments

Comments
 (0)