19
19
TransformConfig ,
20
20
TransformFactory ,
21
21
TransformScheme ,
22
+ apply_transform_config ,
22
23
)
23
24
from compressed_tensors .utils import offloaded_dispatch
24
25
from tests .testing_utils import requires_accelerate , requires_gpu
25
26
26
27
27
- def scheme_kwargs ():
28
- all_types = TransformFactory .registered_names ()
29
- base = [{"type" : type } for type in all_types ]
30
- randomized = [{"type" : type , "randomize" : True } for type in all_types ]
31
- return base + randomized
32
-
33
-
34
- @pytest .mark .parametrize ("scheme_kwargs" , scheme_kwargs ())
35
- def test_correctness_linear (scheme_kwargs ):
28
+ @pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
29
+ @pytest .mark .parametrize ("randomized" , (True , False ))
30
+ def test_correctness_linear (type , randomized ):
36
31
size = (4 , 8 )
37
32
module = torch .nn .Linear (* size , bias = True )
38
- scheme = TransformScheme (** scheme_kwargs )
33
+ scheme = TransformScheme (type = type , randomized = randomized )
39
34
factory = TransformFactory .from_scheme (scheme , name = "" )
40
35
41
36
input_tfm = factory .create_transform (
@@ -59,8 +54,9 @@ def test_correctness_linear(scheme_kwargs):
59
54
assert torch .allclose (true_output , output , atol = 1e-5 , rtol = 0.0 )
60
55
61
56
62
- @pytest .mark .parametrize ("scheme_kwargs" , scheme_kwargs ())
63
- def test_correctness_model (scheme_kwargs , model_apply , offload = False ):
57
+ @pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
58
+ @pytest .mark .parametrize ("randomized" , (True , False ))
59
+ def test_correctness_model (type , randomized , model_apply , offload = False ):
64
60
# load model
65
61
model = model_apply [0 ]
66
62
if offload :
@@ -75,15 +71,10 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
75
71
# apply transforms
76
72
config = TransformConfig (
77
73
config_groups = {
78
- "" : TransformScheme (
79
- ** scheme_kwargs ,
80
- apply = model_apply [1 ],
81
- )
74
+ "" : TransformScheme (type = type , randomized = randomized , apply = model_apply [1 ])
82
75
}
83
76
)
84
- for name , scheme in config .config_groups .items ():
85
- factory = TransformFactory .from_scheme (scheme , name = name )
86
- factory .apply_to_model (model )
77
+ apply_transform_config (model , config )
87
78
88
79
# compare outputs
89
80
output = model (input )
@@ -92,6 +83,7 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
92
83
93
84
@requires_gpu
94
85
@requires_accelerate ()
95
- @pytest .mark .parametrize ("scheme_kwargs" , scheme_kwargs ())
96
- def test_correctness_model_offload (scheme_kwargs , model_apply ):
97
- test_correctness_model (scheme_kwargs , model_apply , offload = True )
86
+ @pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
87
+ @pytest .mark .parametrize ("randomized" , (True , False ))
88
+ def test_correctness_model_offload (type , randomized , model_apply ):
89
+ test_correctness_model (type , randomized , model_apply , offload = True )
0 commit comments