Commit aec9a79

Fix test after removing contiguous() (#2751)

Summary: Didn't repro the error before due to some installation cache

Test Plan: python test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Reviewers:
Subscribers:
Tasks:
Tags:

stack-info: PR: #2751, branch: jerryzh168/stack/27

1 parent 10a0bdd commit aec9a79

File tree

1 file changed: +2 −2 lines changed


test/quantization/quantize_/workflows/float8/test_float8_tensor.py (2 additions, 2 deletions)

@@ -193,14 +193,14 @@ def test_slice(self, granularity):
         # does not differ too much
         input = torch.randn(2, 256, dtype=dtype, device=device)
         res_ref = dummy1(input)
-        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
         res = dummy(input)
         sqnr = compute_error(res, res_ref)
         self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")

         input = torch.randn(2, 128, dtype=dtype, device=device)
         res_ref = dummy2(input)
-        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False)
         res = dummy(input)
         sqnr = compute_error(res, res_ref)
         self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")
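Background on why the `.contiguous()` calls above can matter (this is context, not part of the commit): slicing a tensor along a non-leading dimension produces a non-contiguous view of the original storage, and some code paths require a densely packed tensor. A minimal sketch, assuming `weight1` in the test was produced by such a slice (the `base` tensor and its shape here are hypothetical):

```python
import torch

# Hypothetical setup: weight1 is assumed to come from slicing a larger
# weight along the last dimension, which yields a non-contiguous view.
base = torch.randn(256, 256)
weight1 = base[:, :128]           # column slice: shares storage, strided view
print(weight1.is_contiguous())    # False

# .contiguous() copies the view into a fresh, densely packed tensor,
# which can then safely back an nn.Parameter.
param = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
print(param.is_contiguous())      # True
```

Note that a slice along the first dimension (`base[:128, :]`) would already be contiguous; it is the column slice that forces the copy.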
