Skip to content

Commit 966b50e

Browse files
committed
small cleanup
Signed-off-by: Kyle Sayers <[email protected]>
1 parent a7ed09f commit 966b50e

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

src/compressed_tensors/transform/factory/base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818

1919
import torch
2020
import torch.nn.utils.parametrize as P
21-
from compressed_tensors import InternalModule
22-
from compressed_tensors.quantization.lifecycle import is_target # TODO: move to utils
21+
from compressed_tensors import InternalModule, match_named_modules
2322
from compressed_tensors.registry.registry import RegistryMixin, T
2423
from compressed_tensors.transform import (
2524
TransformArgs,
@@ -91,9 +90,8 @@ def apply_to_model(self, model: Module):
9190
:param model: module to apply transforms to
9291
"""
9392
for arg in self.scheme.apply:
94-
for name, module in list(model.named_modules()):
95-
if is_target(name, module, arg.targets, arg.ignore):
96-
self._apply_to_module(module, arg)
93+
for _, module in match_named_modules(model, arg.targets, arg.ignore):
94+
self._apply_to_module(module, arg)
9795

9896
self._update_tied_weights()
9997

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@
2323
apply_transform_weight,
2424
get_transform_size,
2525
)
26-
from compressed_tensors.utils import get_execution_device, get_offloaded_device
26+
from compressed_tensors.utils import (
27+
get_execution_device,
28+
get_offloaded_device,
29+
match_modules_set,
30+
)
2731
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
2832
from torch import Tensor, device, dtype
29-
from torch.nn import Linear, Module, Parameter
33+
from torch.nn import Module, Parameter
3034

3135

3236
@TransformFactory.register("hadamard")

0 commit comments

Comments
 (0)