Skip to content

Commit e46535d

Browse files
committed
Address review feedback for vllm-project#1687: use Module.get_submodule; fix match_named_modules arg order; clean imports
1 parent f5740fe commit e46535d

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from loguru import logger
1515
from pydantic import ConfigDict, PrivateAttr, model_validator
1616
from torch.nn import Module
17-
from operator import attrgetter
1817
from tqdm import tqdm
1918

2019
from llmcompressor.core import Event, EventType, State
@@ -324,14 +323,16 @@ def _set_resolved_mappings(self, model: Module) -> None:
324323
smooth_layer = smooth_layers[smooth_name]
325324

326325
smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
327-
smooth_parent = attrgetter(smooth_parent_name)(model) if smooth_parent_name else model
326+
smooth_parent = (
327+
model if smooth_parent_name == "" else model.get_submodule(smooth_parent_name)
328+
)
328329

329330
balance_layers, balance_names = [], []
330331
for balance_regex in mapping.balance_layers:
331332
# find the submodules that match the activation layer
332333
for balance_suffix, balance_layer in match_named_modules(
333-
balance_regex,
334334
smooth_parent,
335+
balance_regex,
335336
exclude_internal_modules=True,
336337
).items():
337338
balance_name = f"{smooth_parent_name}.{balance_suffix}"
@@ -766,7 +767,7 @@ def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Mod
766767
while True:
767768
if parent_name == "":
768769
return "", module
769-
parent = attrgetter(parent_name)(module)
770+
parent = module.get_submodule(parent_name)
770771
if not isinstance(parent, torch.nn.ModuleList):
771772
return parent_name, parent
772773
parent_name = ".".join(parent_name.split(".")[:-1])

src/llmcompressor/modifiers/distillation/output/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
6161
else:
6262
model_target, teacher_target = target, target
6363

64-
model_layers = match_named_modules(model_target, state.model)
65-
teacher_layers = match_named_modules(teacher_target, state.teacher_model)
64+
model_layers = match_named_modules(state.model, model_target)
65+
teacher_layers = match_named_modules(state.teacher_model, teacher_target)
6666

6767
if len(model_layers) < 1:
6868
raise ValueError(f"no model layers found for target {target}")

0 commit comments

Comments
 (0)