@@ -64,8 +64,7 @@ def test_correctness_linear(scheme):
64
64
input_transformed = input_tfm (input )
65
65
weight_transformed = w_out_tfm (w_in_tfm (module .weight ))
66
66
output = output_tfm (input_transformed @ weight_transformed .T )
67
-
68
- torch .allclose (true_output , output , atol = 1e-7 , rtol = 0.0 )
67
+ assert torch .allclose (true_output , output , atol = 1e-5 , rtol = 0.0 )
69
68
70
69
71
70
@pytest .mark .parametrize (
@@ -74,14 +73,24 @@ def test_correctness_linear(scheme):
74
73
)
75
74
def test_correctness_model (scheme , offload = False ):
76
75
# load model
77
- model = TransformableModel (2 , 4 , 8 , 16 )
76
+ model = TransformableModel (2 , 4 , 8 , 16 , 32 , 64 )
78
77
if offload :
79
78
model = force_cpu_offload (model , torch .device ("cuda" ))
80
79
81
80
# create factory
82
81
scheme .apply = [
83
- TransformArgs (targets = "fcs.0" , location = "input" ),
84
- TransformArgs (targets = "fcs.2" , location = "output" , inverse = True ),
82
+ # weight output -> input
83
+ TransformArgs (targets = "fcs.0" , location = "weight_output" ),
84
+ TransformArgs (targets = "fcs.1" , location = "input" , inverse = True ),
85
+ # output -> weight input
86
+ TransformArgs (targets = "fcs.1" , location = "output" ),
87
+ TransformArgs (targets = "fcs.2" , location = "weight_input" , inverse = True ),
88
+ # output -> input
89
+ TransformArgs (targets = "fcs.2" , location = "output" ),
90
+ TransformArgs (targets = "fcs.3" , location = "input" , inverse = True ),
91
+ # weight output -> weight input
92
+ TransformArgs (targets = "fcs.3" , location = "weight_output" ),
93
+ TransformArgs (targets = "fcs.4" , location = "weight_input" , inverse = True ),
85
94
]
86
95
factory = TransformFactory .from_scheme (scheme , name = "" )
87
96
@@ -94,7 +103,7 @@ def test_correctness_model(scheme, offload=False):
94
103
true_output = model (input )
95
104
factory .apply_to_model (model )
96
105
output = model (input )
97
- torch .allclose (true_output , output , atol = 1e-7 , rtol = 0.0 )
106
+ assert torch .allclose (true_output , output , atol = 1e-5 , rtol = 0.0 )
98
107
99
108
100
109
@requires_gpu
0 commit comments