Skip to content

Commit 6901e02

Browse files
committed
fix tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent aa7d21b commit 6901e02

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@ class RandomHadamardFactory(HadamardFactory):
2929
"""
3030

3131
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
32-
for key in self.weights.keys():
33-
if key[0] == size:
34-
return self.weights[key].to(dtype=dtype, device=device)
35-
3632
data = random_hadamard_matrix(size) # seed
3733
data = data.to(dtype=dtype, device=device)
3834
return Parameter(data, requires_grad=self.scheme.requires_grad)

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _matmul_hadU(X, transpose=False) -> torch.Tensor:
128128
input = hadK.view(1, K, K).to(input) @ input
129129

130130
# normalize
131-
return input.view(X.shape) / torch.tensor(n).sqrt()
131+
return input.view(X.shape)
132132

133133

134134
def _is_pow2(n: int) -> bool:

tests/test_transform/factory/test_correctness.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ def test_correctness_linear(scheme):
6464
input_transformed = input_tfm(input)
6565
weight_transformed = w_out_tfm(w_in_tfm(module.weight))
6666
output = output_tfm(input_transformed @ weight_transformed.T)
67-
68-
torch.allclose(true_output, output, atol=1e-7, rtol=0.0)
67+
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
6968

7069

7170
@pytest.mark.parametrize(
@@ -74,14 +73,24 @@ def test_correctness_linear(scheme):
7473
)
7574
def test_correctness_model(scheme, offload=False):
7675
# load model
77-
model = TransformableModel(2, 4, 8, 16)
76+
model = TransformableModel(2, 4, 8, 16, 32, 64)
7877
if offload:
7978
model = force_cpu_offload(model, torch.device("cuda"))
8079

8180
# create factory
8281
scheme.apply = [
83-
TransformArgs(targets="fcs.0", location="input"),
84-
TransformArgs(targets="fcs.2", location="output", inverse=True),
82+
# weight output -> input
83+
TransformArgs(targets="fcs.0", location="weight_output"),
84+
TransformArgs(targets="fcs.1", location="input", inverse=True),
85+
# output -> weight input
86+
TransformArgs(targets="fcs.1", location="output"),
87+
TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
88+
# output -> input
89+
TransformArgs(targets="fcs.2", location="output"),
90+
TransformArgs(targets="fcs.3", location="input", inverse=True),
91+
# weight output -> weight input
92+
TransformArgs(targets="fcs.3", location="weight_output"),
93+
TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
8594
]
8695
factory = TransformFactory.from_scheme(scheme, name="")
8796

@@ -94,7 +103,7 @@ def test_correctness_model(scheme, offload=False):
94103
true_output = model(input)
95104
factory.apply_to_model(model)
96105
output = model(input)
97-
torch.allclose(true_output, output, atol=1e-7, rtol=0.0)
106+
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
98107

99108

100109
@requires_gpu

0 commit comments

Comments
 (0)