
Commit 3d49764

[Utils] Skip internal modules when matching (#404)
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 5478b43 commit 3d49764

File tree

6 files changed (+105 additions, -53 deletions)

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 8 additions & 4 deletions
@@ -112,17 +112,21 @@ def dequantize(
             if scale.shape[1] == 1:
                 args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
             # Scale height matches input or is 1 -> group quantization across columns
-            #
+            #
             # Example 1: scale.shape[0] == 1
             # x_q: (4, 8), scale: (1, 4) -> 2 columns per group
             #
-            # Example 2: scale.shape[0] == x_q.shape[0]
+            # Example 2: scale.shape[0] == x_q.shape[0]
             # x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
             elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
                 group_size = int(x_q.shape[1] / scale.shape[1])
-                args = QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=group_size)
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.GROUP, group_size=group_size
+                )
             else:
-                args = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape)
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape
+                )
         else:
             raise ValueError(
                 f"Could not infer a quantization strategy from scale with {scale.ndim} "

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 7 additions & 5 deletions
@@ -185,27 +185,29 @@ def _initialize_scale_zero_point(
     elif quantization_args.strategy == QuantizationStrategy.BLOCK:
         # For block quantization, scale shape should match number of blocks - only for weights
         if quantization_args.block_structure is None:
-            raise ValueError("Block quantization requires block_structure to be specified")
+            raise ValueError(
+                "Block quantization requires block_structure to be specified"
+            )
         block_height, block_width = quantization_args.block_structure
         rows, cols = weight_shape[-2], weight_shape[-1]
         num_rows_blocks = math.ceil(rows / block_height)
         num_cols_blocks = math.ceil(cols / block_width)
-
+
         # Warn if dimensions don't divide evenly
         if rows % block_height != 0 or cols % block_width != 0:
             warnings.warn(
                 f"Block quantization: tensor shape {weight_shape} does not divide evenly "
                 f"by block structure {quantization_args.block_structure}. "
                 f"Some blocks will be incomplete which may affect quantization quality.",
-                UserWarning
+                UserWarning,
             )
-
+
         expected_shape = (num_rows_blocks, num_cols_blocks)
     elif quantization_args.strategy == QuantizationStrategy.BLOCK:
         warnings.warn(
             f"BLOCK quantization not supported for {base_name} activations. "
             f"Falling back to tensor-level quantization.",
-            UserWarning
+            UserWarning,
         )
         expected_shape = 1
 
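As a quick worked example (not from the diff), the BLOCK branch above sizes the scale with one entry per block, rounding up, and warns when the tensor does not divide evenly; the shapes here are hypothetical:

import math

weight_shape = (10, 12)           # rows, cols
block_height, block_width = 4, 4  # block_structure

num_rows_blocks = math.ceil(weight_shape[-2] / block_height)  # ceil(10 / 4) == 3
num_cols_blocks = math.ceil(weight_shape[-1] / block_width)   # ceil(12 / 4) == 3
expected_shape = (num_rows_blocks, num_cols_blocks)           # (3, 3)

# 10 % 4 != 0, so this case would also trigger the "does not divide evenly" warning.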

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 4 additions & 3 deletions
@@ -64,8 +64,9 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
             raise ValueError("Cannot apply actorder to output activations")
 
         if (
-            inputs and weights
-            and weights.strategy == QuantizationStrategy.GROUP
+            inputs
+            and weights
+            and weights.strategy == QuantizationStrategy.GROUP
             and inputs.strategy == QuantizationStrategy.GROUP
             and weights.group_size != inputs.group_size
         ):
@@ -75,7 +76,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
                 "may complicate fused kernel implementations. Consider using "
                 "TENSOR_GROUP strategy for both or matching group sizes.",
                 UserWarning,
-                stacklevel=2
+                stacklevel=2,
             )
 
         return model
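For illustration, a sketch of the condition the reformatted check guards; it assumes QuantizationArgs and QuantizationStrategy are importable from compressed_tensors.quantization and that the constructor defaults used elsewhere in this diff apply:

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

# Hypothetical mismatched pair that would trip the warning above.
weights = QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=128)
inputs = QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=64)

mismatched = bool(
    inputs
    and weights
    and weights.strategy == QuantizationStrategy.GROUP
    and inputs.strategy == QuantizationStrategy.GROUP
    and weights.group_size != inputs.group_size
)
assert mismatched  # validate_model_after would emit the UserWarning for such a scheme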

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 6 additions & 5 deletions
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from typing import Optional, Union
 
-import math
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
@@ -103,7 +103,8 @@ def forward(self, value: Tensor) -> Tensor:
 
         if self.args.inverse:
             weight = weight.T
-
-        return apply_transform_weight(
-            weight, value, self.args.location, self.module_type
-        ) / self._scale
+
+        return (
+            apply_transform_weight(weight, value, self.args.location, self.module_type)
+            / self._scale
+        )
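As background (an illustration of the general property, not the factory's internals), the transpose-plus-shared-scale pattern in the reformatted return works because a Hadamard matrix H of size n satisfies H @ H.T == n * I, so dividing both the forward and the transposed inverse application by the same sqrt(n) round-trips to the identity:

import math

import torch

H = torch.tensor([[1.0, 1.0], [1.0, -1.0]])  # 2x2 Hadamard matrix
n = H.shape[0]
scale = math.sqrt(n)

x = torch.randn(3, n)
y = (x @ H) / scale          # forward-style application
x_back = (y @ H.T) / scale   # inverse path transposes and reuses the same scale
assert torch.allclose(x, x_back, atol=1e-6)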

src/compressed_tensors/utils/match.py

Lines changed: 12 additions & 7 deletions
@@ -18,6 +18,7 @@
 from typing import Iterable, Tuple
 
 import torch
+from compressed_tensors.utils.internal import InternalModule
 
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -28,8 +29,6 @@
     "match_named_parameters",
     "match_modules_set",
     "is_match",
-    "match_name",
-    "match_class",
 ]
 
 
@@ -83,13 +82,16 @@ def match_named_parameters(
     """
     unmatched_targets = set(targets)
     for module_name, module in model.named_modules():
+        if isinstance(module, InternalModule):
+            continue
+
         for param_name, param in module.named_parameters(recurse=False):
             param_fqn = f"{module_name}.{param_name}"
             for target in targets:
-                if match_name(param_fqn, target):
+                if _match_name(param_fqn, target):
                     unmatched_targets -= {target}
 
-                    if not any(match_name(param_fqn, ign) for ign in ignore):
+                    if not any(_match_name(param_fqn, ign) for ign in ignore):
                         yield param_fqn, module, param
 
     if warn_on_fail:
@@ -165,11 +167,14 @@ def match_modules_set(
 def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
     """
     Returns true if either module name or module parent classes match against target
+    and the module is not an internal module
     """
-    return match_name(name, target) or match_class(module, target)
+    return not isinstance(module, InternalModule) and (
+        _match_name(name, target) or _match_class(module, target)
+    )
 
 
-def match_name(name: str, target: str) -> bool:
+def _match_name(name: str, target: str) -> bool:
     """
     Returns true if target string begins with "re:" and
     regex matches or if target string exactly matches name
@@ -180,7 +185,7 @@ def match_name(name: str, target: str) -> bool:
     return target == name
 
 
-def match_class(module: torch.nn.Module, target: str) -> bool:
+def _match_class(module: torch.nn.Module, target: str) -> bool:
     """
     Returns true if any torch parent class names match the target string exactly
     """

tests/test_utils/test_match.py

Lines changed: 68 additions & 29 deletions
@@ -19,14 +19,14 @@
 from accelerate import init_empty_weights
 
 # Assuming the module is named "module_matching" - adjust import as needed
-from compressed_tensors.utils.match import (
+from compressed_tensors.utils import (
+    InternalModule,
     is_match,
-    match_class,
     match_modules_set,
-    match_name,
     match_named_modules,
     match_named_parameters,
 )
+from compressed_tensors.utils.match import _match_class, _match_name
 
 
 class DummyModel(nn.Module):
@@ -66,14 +66,14 @@ def __init__(self):
 
 
 class TestMatchName:
-    """Test cases for match_name function"""
+    """Test cases for _match_name function"""
 
     def test_exact_match(self):
         """Test exact string matching"""
-        assert match_name("layer1", "layer1") == True
-        assert match_name("layer1", "layer2") == False
+        assert _match_name("layer1", "layer1") == True
+        assert _match_name("layer1", "layer2") == False
         assert (
-            match_name(
+            _match_name(
                 "transformer.layers.0.self_attn.q_proj",
                 "transformer.layers.0.self_attn.q_proj",
             )
@@ -82,14 +82,14 @@ def test_exact_match(self):
 
     def test_regex_match(self):
         """Test regex matching with "re:" prefix"""
-        assert match_name("layer1", "re:layer.*") == True
-        assert match_name("layer1", "re:^layer1$") == True
-        assert match_name("layer1", "re:layer2") == False
+        assert _match_name("layer1", "re:layer.*") == True
+        assert _match_name("layer1", "re:^layer1$") == True
+        assert _match_name("layer1", "re:layer2") == False
         assert (
-            match_name("transformer.layers.0.self_attn.q_proj", "re:.*q_proj") == True
+            _match_name("transformer.layers.0.self_attn.q_proj", "re:.*q_proj") == True
         )
         assert (
-            match_name(
+            _match_name(
                 "transformer.layers.0.self_attn.q_proj",
                 "re:transformer\\.layers\\.\\d+\\.self_attn\\..*_proj$",
             )
@@ -98,49 +98,49 @@ def test_regex_match(self):
 
     def test_empty_strings(self):
         """Test edge cases with empty strings"""
-        assert match_name("", "") == True
-        assert match_name("layer1", "") == False
-        assert match_name("", "layer1") == False
+        assert _match_name("", "") == True
+        assert _match_name("layer1", "") == False
+        assert _match_name("", "layer1") == False
 
     def test_regex_special_characters(self):
         """Test regex with special characters"""
-        assert match_name("layer.1", "re:layer\\.1") == True
-        assert match_name("layer.1", "re:layer.1") == True  # . matches any char
-        assert match_name("layer_1", "re:layer_1") == True
+        assert _match_name("layer.1", "re:layer\\.1") == True
+        assert _match_name("layer.1", "re:layer.1") == True  # . matches any char
+        assert _match_name("layer_1", "re:layer_1") == True
 
 
 class TestMatchClass:
-    """Test cases for match_class function"""
+    """Test cases for _match_class function"""
 
     def test_direct_class_match(self):
         """Test matching direct class names"""
         linear = nn.Linear(10, 20)
-        assert match_class(linear, "Linear") == True
-        assert match_class(linear, "Conv2d") == False
+        assert _match_class(linear, "Linear") == True
+        assert _match_class(linear, "Conv2d") == False
 
         norm = nn.LayerNorm(10)
-        assert match_class(norm, "LayerNorm") == True
-        assert match_class(norm, "BatchNorm1d") == False
+        assert _match_class(norm, "LayerNorm") == True
+        assert _match_class(norm, "BatchNorm1d") == False
 
     def test_parent_class_match(self):
         """Test matching parent class names"""
         linear = nn.Linear(10, 20)
-        assert match_class(linear, "Module") == True
+        assert _match_class(linear, "Module") == True
 
         conv = nn.Conv2d(3, 16, 3)
-        assert match_class(conv, "Module") == True
-        assert match_class(conv, "_ConvNd") == True
+        assert _match_class(conv, "Module") == True
+        assert _match_class(conv, "_ConvNd") == True
 
     def test_non_torch_module(self):
         """Test with non-torch modules"""
         regular_object = object()
-        assert match_class(regular_object, "object") == False  # not a torch.nn.Module
+        assert _match_class(regular_object, "object") == False  # not a torch.nn.Module
 
     def test_custom_module(self):
         """Test with custom module classes"""
         model = DummyModel()
-        assert match_class(model, "DummyModel") == True
-        assert match_class(model, "Module") == True
+        assert _match_class(model, "DummyModel") == True
+        assert _match_class(model, "Module") == True
 
 
 class TestIsMatch:
@@ -171,6 +171,15 @@ def test_regex_in_name_match(self):
         assert is_match("layer1", linear, "re:layer.*") == True
         assert is_match("layer1", linear, "re:conv.*") == False
 
+    def test_internal_module_match(self):
+        """Test not matching internal modules"""
+
+        class InternalLinear(InternalModule, nn.Linear):
+            pass
+
+        linear = InternalLinear(10, 20)
+        assert is_match("layer1", linear, "re:layer.*") == False
+
 
 class TestMatchNamedModules:
     """Test cases for match_named_modules function"""
@@ -236,6 +245,16 @@ def test_warn_on_fail(self, mock_logger):
         assert "Could not match" in warning_msg
         assert "nonexistent_module" in warning_msg
 
+    def test_internal_match(self):
+        """Test not matching internal modules"""
+
+        class InternalLinear(InternalModule, nn.Linear):
+            pass
+
+        linear = InternalLinear(10, 20)
+        matches = list(match_named_modules(linear, ["re:.*"]))
+        assert len(matches) == 0
+
 
 class TestMatchNamedParameters:
     """Test cases for match_named_parameters function"""
@@ -298,6 +317,16 @@ def test_warn_on_fail_parameters(self, mock_logger):
         assert "Could not match" in warning_msg
         assert "nonexistent.param" in warning_msg
 
+    def test_internal_match(self):
+        """Test not matching internal modules"""
+
+        class InternalLinear(InternalModule, nn.Linear):
+            pass
+
+        linear = InternalLinear(10, 20)
+        matches = list(match_named_parameters(linear, ["re:.*"]))
+        assert len(matches) == 0
+
 
 class TestMatchModulesSet:
     """Test cases for match_modules_set function"""
@@ -377,6 +406,16 @@ def test_module_set_with_ignore(self):
         # Should have 2 sets (layers 1 and 2, but not 0)
         assert len(matches) == 2
 
+    def test_internal_match(self):
+        """Test not matching internal modules"""
+
+        class InternalLinear(InternalModule, nn.Linear):
+            pass
+
+        linear = InternalLinear(10, 20)
+        matches = list(match_modules_set(linear, ["re:.*"]))
+        assert len(matches) == 0
+
 
 class TestIntegration:
     """Integration tests combining multiple functions"""

0 commit comments
