Skip to content

Commit 6ddd036

Browse files
ZewenShen-Cohere, dsikka, and HDCharles
authored
[AWQ] Add activation_hook_target field for custom activation cache hooking (vllm-project#2346)
## Summary - Adds an optional `activation_hook_target` field to `AWQMapping` that lets users specify which submodule (relative to the parent/LCA) to hook for activation caching, replacing the hardcoded `hasattr(parent, 'mlp')` workaround for MoE models with parallel transformer blocks. - When `activation_hook_target` is `None` (default), behavior is unchanged: the hook is placed on `balance_layers[0]`. When set (e.g. `"mlp"`), it resolves to the corresponding submodule on the parent via `getattr_chain`. ## Motivation In parallel transformer architectures, attention and MLP run in parallel from the same input. The existing code always hooks `balance_layers[0]` for activation caching, which captures the wrong activations when balance layers span both attention and MLP branches. There was a commented-out `hasattr(parent, 'mlp')` workaround, but it was brittle and not generalizable. This change makes the hook target explicitly configurable per mapping. ## Test I've tested this change with our internal models, and it aligns with previous results. --------- Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com> Co-authored-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
1 parent b0463d1 commit 6ddd036

File tree

2 files changed

+68
-9
lines changed

2 files changed

+68
-9
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ class AWQModifier(Modifier, QuantizationMixin):
7676
balance_layers: ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
7777
- smooth_layer: "re:.*final_layer_norm"
7878
balance_layers: ["re:.*fc1"]
79+
# activation_hook_target specifies which submodule of the parent to hook
80+
# for activation caching.
81+
# This change is only useful for MoE models with parallel transformer blocks,
82+
# and one should use the default value (None) in most cases.
7983
ignore: ["lm_head"]
8084
config_groups:
8185
group_0:
@@ -122,6 +126,11 @@ class AWQModifier(Modifier, QuantizationMixin):
122126
to smoothed) and the second entry is the layer whose output is scaled to
123127
achieve the smoothing.
124128
If regex is used, it matches layers with the largest overlap in module name.
129+
Each mapping may also include an ``activation_hook_target``: a dotted
130+
attribute path relative to the parent module (lowest common ancestor)
131+
specifying which submodule to hook for activation caching. This is useful
132+
for parallel transformer blocks where the default (hooking
133+
``balance_layers[0]``) would capture the wrong activations.
125134
:param ignore: list of layers to ignore during quantization (not smoothed).
126135
It should match the name of layers whose outputs are scaled to achieve
127136
smoothing (the second entry of the mappings list).
@@ -389,6 +398,17 @@ def _set_resolved_mappings(self, model: Module) -> None:
389398
balance_names, model, torch.nn.ModuleList
390399
)
391400

401+
activation_hook_target = None
402+
if mapping.activation_hook_target:
403+
activation_hook_target = getattr_chain(
404+
ancestor, mapping.activation_hook_target
405+
)
406+
if activation_hook_target is None:
407+
raise ValueError(
408+
f"activation_hook_target '{mapping.activation_hook_target}'"
409+
f" not found on parent module '{ancestor_name}'"
410+
)
411+
392412
resolved_mappings.append(
393413
ResolvedMapping(
394414
smooth_name,
@@ -397,6 +417,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
397417
balance_names=balance_names,
398418
parent=ancestor,
399419
parent_name=ancestor_name,
420+
activation_hook_target=activation_hook_target,
400421
)
401422
)
402423
self._resolved_mappings = resolved_mappings
@@ -468,16 +489,14 @@ def cache_smooth_activations_hook(
468489
# input activations to balance layers needed for loss function
469490
# storing inputs to first balance layer is sufficient
470491
# other balance layers get the same input
471-
472-
# The line below is useful for models that use parallel transformer block,
473-
# such as gemma 3, command A. Need a better way to integrate it to the code.
474-
# layer_to_hook = (
475-
# mapping.parent.mlp
476-
# if hasattr(mapping.parent, 'mlp')
477-
# else mapping.balance_layers[0]
478-
# )
492+
#
493+
# For parallel transformer blocks (e.g. Command A, Gemma 3) the first
494+
# balance layer may not receive the right activations. When
495+
# activation_hook_target is set on the mapping, hook that module
496+
# instead of balance_layers[0].
497+
layer_to_hook = mapping.activation_hook_target or mapping.balance_layers[0]
479498
self.register_hook(
480-
mapping.balance_layers[0],
499+
layer_to_hook,
481500
create_cache_smooth_activations_hook_fn(mapping.smooth_name),
482501
"forward",
483502
)

src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,21 @@ class AWQMapping:
1616
`AWQMapping`s are resolved into `ResolvedMapping`s, which
1717
retain pointers to the actual `torch.nn.Module`s and additional
1818
metadata at runtime
19+
20+
:param smooth_layer: regex or name of the activation layer to smooth
21+
:param balance_layers: list of regex or names of weight layers that must
22+
be balanced to offset the smoothing
23+
:param activation_hook_target: optional dotted attribute path relative to the
24+
parent module (lowest common ancestor of balance_layers) specifying which
25+
submodule to hook for activation caching. Useful for parallel transformer
26+
blocks (e.g. Cohere, Gemma 3) where the first balance layer is not the
27+
correct place to capture activations. When ``None`` (default), the hook
28+
is placed on ``balance_layers[0]``.
1929
"""
2030

2131
smooth_layer: str
2232
balance_layers: list[str]
33+
activation_hook_target: str | None = None
2334

2435

2536
_default_mappings = [
@@ -181,6 +192,30 @@ class AWQMapping:
181192
),
182193
]
183194

195+
# Example mapping for MoE models with parallel transformer blocks, where
196+
# attention and MoE share the same input. This is the only case where
197+
# activation_hook_target is needed. Without it, the hook lands on
198+
# balance_layers[0] — which could be a single expert — capturing only that expert's
199+
# input rather than the full activation flowing into the MLP & Attention branch.
200+
# Setting activation_hook_target="mlp" hooks parent.mlp instead, so the cached
201+
# activations reflect the complete input to the MoE & Attention branch.
202+
_example_parallel_transformer_block_mappings = [
203+
AWQMapping(
204+
"re:.*input_layernorm$",
205+
[
206+
"re:.*mlp.experts.[0-9]+.gate_proj$",
207+
"re:.*mlp.experts.[0-9]+.up_proj$",
208+
"re:.*mlp.shared_experts.gate_proj$",
209+
"re:.*mlp.shared_experts.up_proj$",
210+
"re:.*mlp.gate$",
211+
"re:.*q_proj$",
212+
"re:.*k_proj$",
213+
"re:.*v_proj$",
214+
],
215+
activation_hook_target="mlp",
216+
)
217+
]
218+
184219
AWQ_MAPPING_REGISTRY: dict[str, list[AWQMapping]] = {
185220
"BloomForCausalLM": _bloom_mappings,
186221
"CohereForCausalLM": _cohere_mappings,
@@ -223,6 +258,10 @@ class ResolvedMapping:
223258
:param balance_names: optional list of names of the balance_layers
224259
:param parent: parent module of the balance_layers
225260
:param parent_name: name of the parent module
261+
:param activation_hook_target: optional resolved module to hook for activation
262+
caching. When set, the activation cache hook is placed on this module
263+
instead of ``balance_layers[0]``. Populated from
264+
``AWQMapping.activation_hook_target``.
226265
"""
227266

228267
smooth_name: str
@@ -231,6 +270,7 @@ class ResolvedMapping:
231270
balance_names: list[str]
232271
parent: Module
233272
parent_name: str
273+
activation_hook_target: Module | None = None
234274

235275

236276
def get_layer_mappings_from_architecture(architecture: str) -> list[AWQMapping]:

0 commit comments

Comments (0)