
Commit e740de3

required changes

1 parent f012ae6

4 files changed: +8 additions, -5 deletions


src/llmcompressor/modifiers/awq/base.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -324,7 +324,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
         smooth_layer = smooth_layers[smooth_name]

         smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
-        smooth_parent = attrgetter(smooth_parent_name)(model)
+        smooth_parent = attrgetter(smooth_parent_name)(model) if smooth_parent_name else model

         balance_layers, balance_names = [], []
         for balance_regex in mapping.balance_layers:
```
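The guard matters because `attrgetter` cannot resolve an empty attribute path. A minimal sketch, assuming a toy `nn.Sequential` model, of why a top-level module name needs the `else model` fallback:

```python
from operator import attrgetter

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4))

# For a top-level module name such as "0", stripping the last path
# component leaves an empty parent path, and attrgetter("") raises
# AttributeError instead of returning the root module.
smooth_name = "0"
smooth_parent_name = ".".join(smooth_name.split(".")[:-1])  # -> ""
smooth_parent = attrgetter(smooth_parent_name)(model) if smooth_parent_name else model
assert smooth_parent is model  # the model root acts as the parent
```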

src/llmcompressor/modifiers/obcq/sgpt_base.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -17,6 +17,9 @@
 )
 from compressed_tensors import match_named_modules

+def get_prunable_targets():
+    """Return the list of prunable layer types."""
+    return ["Linear", "Conv1d", "Conv2d", "Conv3d", "QATLinear", "QATConv2d", "QATConv3d", "Conv1D"]

 class SparsityModifierBase(Modifier):
     """
@@ -147,7 +150,7 @@ def on_start(self, state: State, event: Event, **kwargs):
             layer_sparsity = self.sparsity[index]
         else:
             layer_sparsity = self.sparsity
-        prunable_targets = ["Linear", "Conv1d", "Conv2d", "Conv3d", "QATLinear", "QATConv2d", "QATConv3d", "Conv1D"]
+        prunable_targets = get_prunable_targets()
         for name, module in match_named_modules(layer, prunable_targets).items():
             name = f"{layer_name}.{name}"

@@ -208,7 +211,7 @@ def _infer_owl_layer_sparsity(

         groups = {}
         for name, layer in layers.items():
-            prunable_targets = ["Linear", "Conv1d", "Conv2d", "Conv3d", "QATLinear", "QATConv2d", "QATConv3d", "Conv1D"]
+            prunable_targets = get_prunable_targets()
             prunable_layers = match_named_modules(layer, prunable_targets)
             z = [
                 m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
```
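Hoisting the hard-coded list into `get_prunable_targets()` gives both call sites a single source of truth. As a rough illustration of the class-name matching the targets list feeds into (a simplified stand-in, not `compressed_tensors`' actual `match_named_modules`, which has richer target semantics):

```python
from typing import Dict, List

import torch.nn as nn

def match_by_class_name(root: nn.Module, targets: List[str]) -> Dict[str, nn.Module]:
    # Simplified stand-in: select submodules whose class name appears in
    # the targets list, e.g. the one returned by get_prunable_targets().
    return {
        name: module
        for name, module in root.named_modules()
        if type(module).__name__ in targets
    }

layer = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Conv2d(3, 3, 1))
targets = ["Linear", "Conv1d", "Conv2d", "Conv3d"]  # subset of get_prunable_targets()
print(match_by_class_name(layer, targets).keys())  # dict_keys(['0', '2'])
```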

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -206,7 +206,7 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         balance_layers = []
         for balance_suffix in to_balance:
             # find the submodule that matches the activation layer
-            _, balance_layer =match_modules_set(
+            _, balance_layer = match_modules_set(
                 balance_suffix, layer_name, model
             )
             if balance_layer:
```

src/llmcompressor/transformers/compression/helpers.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -76,8 +76,8 @@ def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str]
     # check for the common sparsity structures
     structures = {"2:4"}
     for sparsity_structure in structures:
-        linear_modules = match_named_modules(model, linear=True)
         offloaded_params = get_state_dict_offloaded_model(model)
+        linear_modules = match_named_modules(model, ["Linear"])

         linear_modules_with_sparsity_structure = [
             tensor_follows_mask_structure(offloaded_params[f"{name}.weight"])
```
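For context, the "2:4" structure being checked here allows at most two nonzero values in every contiguous group of four weights. A minimal sketch of such a check, under that assumption (a hypothetical stand-in, not the library's `tensor_follows_mask_structure`):

```python
import torch

def follows_2_4_structure(weight: torch.Tensor) -> bool:
    # 2:4 semi-structured sparsity: at most two nonzeros in each
    # contiguous group of four values of the flattened tensor.
    groups = weight.reshape(-1, 4)        # assumes numel is divisible by 4
    nonzeros = (groups != 0).sum(dim=1)
    return bool((nonzeros <= 2).all())

w = torch.tensor([[1.0, 0.0, 2.0, 0.0], [0.0, 3.0, 0.0, 0.0]])
print(follows_2_4_structure(w))  # True
```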
