Skip to content

Commit d94f655

Browse files
committed
Ensure match_targets doesn't return duplicates
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent d22149f commit d94f655

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def apply_quantization_config(
154154
):
155155
# mark modules to be quantized by adding
156156
# quant scheme to the matching layers
157-
matched_targets = list(match_targets(name, submodule, target_to_scheme))
157+
matched_targets = match_targets(name, submodule, target_to_scheme)
158158
scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
159159
if run_compressed:
160160
format = config.format

src/compressed_tensors/utils/match.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,25 @@ def match_named_parameters(
117117

118118
def match_targets(
119119
name: str, module: torch.nn.Module, targets: Iterable[str]
120-
) -> Generator[str]:
120+
) -> List[str]:
121121
"""
122-
Yields the targets that match the given name and module.
122+
Returns the targets that match the given name and module.
123123
Outputs are ordered by type: exact name match, regex name match, class name match
124124
"""
125+
if isinstance(module, InternalModule):
126+
return []
127+
125128
targets = sorted(targets, key=lambda x: ("re:" in x, x))
129+
matched_targets = []
126130
for target in targets:
127131
if _match_name(name, target):
128-
yield target
132+
matched_targets.append(target)
129133

130134
for target in targets:
131-
if _match_class(module, target):
132-
yield target
135+
if _match_class(module, target) and target not in matched_targets:
136+
matched_targets.append(target)
137+
138+
return matched_targets
133139

134140

135141
def match_modules_set(

0 commit comments

Comments
 (0)