Commit 3394c8c

removed get_quantizable,prunable,terminal_layers
1 parent 6303210

3 files changed, +6 −72 lines changed


src/llmcompressor/modifiers/obcq/sgpt_base.py (5 additions & 6 deletions)
@@ -14,8 +14,6 @@
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.utils.pytorch.module import (
     get_no_split_params,
-    get_prunable_layers,
-    match_targets,
 )
 from compressed_tensors import match_named_modules

@@ -149,11 +147,11 @@ def on_start(self, state: State, event: Event, **kwargs):
                 layer_sparsity = self.sparsity[index]
             else:
                 layer_sparsity = self.sparsity
-
-            for name, module in get_prunable_layers(layer).items():
+            prunable_targets = ["Linear", "Conv1d", "Conv2d", "Conv3d", "QATLinear", "QATConv2d", "QATConv3d", "Conv1D"]
+            for name, module in match_named_modules(layer, prunable_targets).items():
                 name = f"{layer_name}.{name}"

-                if match_targets(name, self.ignore)[0]:
+                if match_named_modules(name, self.ignore)[0]:
                     continue

                 # HACK: previously, embeddings were not quantized because they were not
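The rewritten loop selects candidate modules by class-name target strings rather than by isinstance() checks. A minimal sketch of what such name-based selection does; select_by_class is an illustrative stand-in, not the compressed_tensors API the diff actually calls:

from typing import Dict, List
from torch.nn import Linear, Module, Sequential

def select_by_class(root: Module, targets: List[str]) -> Dict[str, Module]:
    # Walk all submodules and keep those whose class name is in the target list.
    return {
        name: module
        for name, module in root.named_modules()
        if type(module).__name__ in targets
    }

layer = Sequential(Linear(8, 8), Linear(8, 4))
print(list(select_by_class(layer, ["Linear"])))  # ['0', '1']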
@@ -210,7 +208,8 @@ def _infer_owl_layer_sparsity(

         groups = {}
         for name, layer in layers.items():
-            prunable_layers = get_prunable_layers(layer)
+            prunable_targets = ["Linear", "Conv1d", "Conv2d", "Conv3d", "QATLinear", "QATConv2d", "QATConv3d", "Conv1D"]
+            prunable_layers = match_named_modules(layer, prunable_targets)
             z = [
                 m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
                 for n, m in prunable_layers.items()
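The z terms above scale each weight's magnitude by the magnitude of its input activation, which is how the OWL heuristic locates activation outliers when inferring per-layer sparsities. A minimal shape sketch, assuming activations[...] holds a per-input-channel magnitude vector (the actual capture is done by the modifier's hooks):

import torch

weight = torch.randn(8, 16)      # a prunable Linear weight (out_features, in_features)
activation = torch.rand(16)      # captured per-input-channel activation magnitude
z = weight.abs() * activation.unsqueeze(0)  # broadcasts over the output rows
print(z.shape)  # torch.Size([8, 16])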

src/llmcompressor/modifiers/smoothquant/base.py (1 addition & 2 deletions)
@@ -16,7 +16,6 @@
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.pytorch.module import (
     get_matching_layer,
-    match_targets,
 )
 from compressed_tensors import match_named_modules
 MINIMUM_SMOOTHING_SCALE = 1e-5
@@ -205,7 +204,7 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         for to_balance, to_smooth in self.mappings:
             to_smooth_layers = match_named_modules(to_smooth, model)
             for layer_name, smooth_layer in to_smooth_layers.items():
-                if not match_targets(layer_name, self.ignore)[0]:
+                if not match_named_modules(layer_name, self.ignore)[0]:
                     balance_layers = []
                     for balance_suffix in to_balance:
                         # find the submodule that matches the activation layer
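For each (to_balance, to_smooth) mapping, layers matching to_smooth are resolved and, unless ignored, paired with nearby balance layers. A simplified sketch of that resolution using the "re:"-prefix convention from match_targets; the mapping shape follows the loop above, but name_matches and resolve_mapping are illustrative stand-ins, not llmcompressor's get_matching_layer:

import re
from typing import List
from torch.nn import Module

def name_matches(name: str, targets: List[str]) -> bool:
    # "re:" entries are regex patterns; everything else is an exact name.
    return any(
        re.match(t[3:], name) if t.startswith("re:") else name == t
        for t in targets
    )

def resolve_mapping(model: Module, to_balance: List[str], to_smooth: str, ignore: List[str]):
    resolved = []
    for name, module in model.named_modules():
        if name_matches(name, [to_smooth]) and not name_matches(name, ignore):
            parent = name.rsplit(".", 1)[0] if "." in name else ""
            balance = [
                m for n, m in model.named_modules()
                if n.startswith(parent) and name_matches(n, to_balance)
            ]
            resolved.append((name, module, balance))
    return resolved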

src/llmcompressor/utils/pytorch/module.py (0 additions & 64 deletions)
@@ -47,10 +47,6 @@


 __all__ = [
-    "match_targets",
-    "get_terminal_layers",
-    "get_prunable_layers",
-    "get_quantizable_layers",
     "qat_active",
     "get_matching_layer",
     "get_no_split_params",
@@ -61,21 +57,6 @@
 ALL_QUANTIZABLE_TARGET = "__ALL_QUANTIZABLE__"


-def match_targets(name: str, targets: Union[str, List[str]]) -> Tuple[bool, int]:
-    if isinstance(targets, str):
-        targets = [targets]
-
-    for index, target in enumerate(targets):
-        if target[:3] == "re:":
-            pattern = target[3:]
-            if re.match(pattern, name):
-                return True, index
-        elif name == target:
-            return True, index
-
-    return False, -1
-
-
 def match_class(layer: Module, targets: Union[str, List[str]]) -> Tuple[bool, int]:
     if isinstance(targets, str):
         targets = [targets]
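The removed match_targets treated "re:"-prefixed entries as regular expressions and all other entries as exact module names, returning a match flag plus the index of the first matching target. With the definition above in scope:

print(match_targets("model.layers.0.mlp.gate_proj", ["re:.*gate_proj"]))
# (True, 0)
print(match_targets("lm_head", ["model.embed_tokens", "lm_head"]))
# (True, 1)
print(match_targets("model.norm", ["re:.*proj"]))
# (False, -1)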
@@ -87,51 +68,6 @@ def match_class(layer: Module, targets: Union[str, List[str]]) -> Tuple[bool, int]:
     return False, -1


-def get_terminal_layers(module: Module) -> Dict[str, Module]:
-    terminal = {}
-
-    for name, layer in module.named_modules():
-        if len(list(layer.named_modules())) > 1:
-            continue
-
-        terminal[name] = layer
-
-    return terminal
-
-
-def get_prunable_layers(module: Module) -> Dict[str, Module]:
-    prunable = {}
-
-    for name, layer in module.named_modules():
-        if (
-            isinstance(layer, Linear)
-            or isinstance(layer, _ConvNd)
-            or (QATLinear and isinstance(layer, QATLinear))
-            or (QATConv2d and isinstance(layer, QATConv2d))
-            or (QATConv3d and isinstance(layer, QATConv3d))
-            or (TransformerConv1D and isinstance(layer, TransformerConv1D))
-        ):
-            prunable[name] = layer
-
-    return prunable
-
-
-def get_quantizable_layers(module: Module) -> Dict[str, Module]:
-    if QATLinear is None:
-        raise ImportError(
-            "PyTorch version is not setup for Quantization. "
-            "Please install a QAT compatible version of PyTorch"
-        )
-
-    quantizable = {}
-
-    for name, layer in module.named_modules():
-        if isinstance(layer, Linear) or isinstance(layer, _ConvNd):
-            quantizable[name] = layer
-
-    return quantizable
-
-
 def qat_active(module: Module) -> bool:
     """
     Determines if any layers in the model have quantization enabled by checking for
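With these helpers gone from __all__ and the module body, downstream code that still needs them has to rebuild the behavior from torch.nn primitives. A minimal sketch of the deleted get_terminal_layers, where a terminal layer is a module with no children:

from typing import Dict
from torch.nn import Linear, Module, Sequential

def terminal_layers(module: Module) -> Dict[str, Module]:
    # Keep only leaf modules, i.e. those with no child modules.
    return {
        name: layer
        for name, layer in module.named_modules()
        if not any(True for _ in layer.children())
    }

model = Sequential(Linear(4, 4), Sequential(Linear(4, 2)))
print(list(terminal_layers(model)))  # ['0', '1.0']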
