File tree: 2 files changed (+9 −7)

src/compressed_tensors/transform/factory
First changed file (contains apply_to_model):

 import torch
 import torch.nn.utils.parametrize as P
-from compressed_tensors import InternalModule
-from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
+from compressed_tensors import InternalModule, match_named_modules
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,

@@ -91,9 +90,8 @@ def apply_to_model(self, model: Module):
         :param model: module to apply transforms to
         """
         for arg in self.scheme.apply:
-            for name, module in list(model.named_modules()):
-                if is_target(name, module, arg.targets, arg.ignore):
-                    self._apply_to_module(module, arg)
+            for _, module in match_named_modules(model, arg.targets, arg.ignore):
+                self._apply_to_module(module, arg)

         self._update_tied_weights()
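Note on the change above: the hand-rolled named_modules() loop plus is_target filter is replaced by a single call to match_named_modules, which also removes the "TODO: move to utils" import from the quantization lifecycle. As a rough illustration of the behavior the call site relies on (yielding (name, module) pairs that match some entry in targets and no entry in ignore), here is a minimal sketch; the pattern rules ("re:"-prefixed regexes or plain class/module names) are an assumption borrowed from how targets are usually written in compressed-tensors configs, not a copy of the library's implementation.

# Minimal sketch of a match_named_modules-style helper (illustrative only).
import re
from typing import Iterable, Iterator, Tuple

import torch
from torch.nn import Module


def match_named_modules_sketch(
    model: Module,
    targets: Iterable[str],
    ignore: Iterable[str] = (),
) -> Iterator[Tuple[str, Module]]:
    """Yield (name, module) pairs matching any target and no ignore entry.

    Assumption: a pattern is either "re:<regex>" matched against the dotted
    module name, or a plain string compared against the module's class name
    or dotted name.
    """

    def matches(name: str, module: Module, patterns: Iterable[str]) -> bool:
        for pattern in patterns:
            if pattern.startswith("re:"):
                if re.match(pattern[3:], name):
                    return True
            elif pattern in (name, module.__class__.__name__):
                return True
        return False

    for name, module in model.named_modules():
        if matches(name, module, targets) and not matches(name, module, ignore):
            yield name, module


# Example: pick out Linear layers, mirroring how apply_to_model iterates matches.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
for name, module in match_named_modules_sketch(model, targets=["Linear"]):
    print(name, module)  # prints the two Linear submodules with their dotted names

The net effect at the call site is the same set of modules as before, with the name/pattern filtering consolidated in one shared utility instead of being inlined in apply_to_model.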
Second changed file (registers the "hadamard" transform factory):

     apply_transform_weight,
     get_transform_size,
 )
-from compressed_tensors.utils import get_execution_device, get_offloaded_device
+from compressed_tensors.utils import (
+    get_execution_device,
+    get_offloaded_device,
+    match_modules_set,
+)
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter


 @TransformFactory.register("hadamard")
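This hunk only touches the imports of the hadamard factory: the unused Linear import is dropped and match_modules_set is pulled in alongside the device helpers; its call site is outside the visible diff. As a loose, hypothetical sketch of what a "match modules set" helper can be useful for (an assumption from the name, not the library's documented behavior): yielding one matched module per target pattern as a group, so corresponding layers can be processed together.

# Hypothetical sketch of grouping one match per target (illustrative only).
from typing import Iterable, Iterator, List, Tuple

from torch import nn


def match_modules_set_sketch(
    model: nn.Module,
    targets: Iterable[str],
) -> Iterator[Tuple[nn.Module, ...]]:
    """Yield tuples with one module per target, grouped in traversal order.

    Matching here is a simple suffix check on the dotted module name; the
    real utility in compressed_tensors.utils may use different rules.
    """
    targets = list(targets)
    buckets: List[List[nn.Module]] = [[] for _ in targets]
    for name, module in model.named_modules():
        for index, target in enumerate(targets):
            if name.endswith(target):
                buckets[index].append(module)
        if all(buckets):
            # one match collected for every target: emit them as a group
            yield tuple(bucket.pop(0) for bucket in buckets)


def make_block() -> nn.ModuleDict:
    return nn.ModuleDict({"up": nn.Linear(4, 8), "down": nn.Linear(8, 4)})


# Example: pair each block's "up" and "down" projections.
model = nn.ModuleList([make_block(), make_block()])
for up, down in match_modules_set_sketch(model, targets=["up", "down"]):
    print(up, down)  # one (up, down) pair per block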