Skip to content

Commit 965fd5d

Browse files
committed
test update
1 parent 4c11d6d commit 965fd5d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/test_linear4bit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
5656
compute_dtype=compute_dtype,
5757
compress_statistics=compress_statistics,
5858
quant_type=quant_type,
59-
device=device,
59+
device='meta',
6060
)
6161
linear_q2.weight = weight2.to(device)
6262
if bias:
63-
linear_q2.bias.data = bias_data2
63+
linear_q2.bias = torch.nn.Parameter(bias_data2)
6464

6565
# matching
6666
a, b = linear_q.weight, linear_q2.weight
@@ -93,7 +93,7 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
9393
assert torch.equal(a, b)
9494

9595
# Forward test
96-
x = torch.rand(42, linear_q.shape[-1], device=device)
96+
x = torch.rand(42, layer_shape[0], device=device)
9797
a = linear_q(x)
9898
b = linear_q2(x)
9999
assert a.device == b.device

0 commit comments

Comments
 (0)