
Commit d1d0dfe

Move deprecated test
1 parent 69d4958 commit d1d0dfe

File tree

tests/test_deprecated.py
tests/test_modules.py

2 files changed: +31 -31 lines changed


tests/test_deprecated.py

Lines changed: 31 additions & 0 deletions
@@ -142,3 +142,34 @@ def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
         grad_err = (gradB1 - gradB2).abs().mean()
         assert grad_err.item() < 0.003
         torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
+
+
+@pytest.mark.deprecated
+def test_fp8linear():
+    b = 10
+    h = 1024
+    inp = torch.randn(b, h).cuda()
+    fp32 = torch.nn.Linear(h, h * 2).cuda()
+    fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
+    fp32b = torch.nn.Linear(h * 2, h).cuda()
+    fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
+
+    fp8.weight.data.copy_(fp32.weight.data)
+    fp8.bias.data.copy_(fp32.bias.data)
+    fp8b.weight.data.copy_(fp32b.weight.data)
+    fp8b.bias.data.copy_(fp32b.bias.data)
+
+    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
+    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
+
+    err = (a - b).abs().mean()
+
+    a.mean().backward()
+    b.mean().backward()
+
+    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
+    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
+
+    assert err < 0.05
+    assert graderr < 0.00002
+    assert bgraderr < 0.00002
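
For context, the moved test uses bnb.research.nn.LinearFP8Mixed as a drop-in replacement for torch.nn.Linear. A minimal standalone sketch of that usage, assuming a CUDA-enabled bitsandbytes install and relying only on the API surface the test itself exercises:

import torch

import bitsandbytes as bnb

# LinearFP8Mixed takes (in_features, out_features) like torch.nn.Linear and
# exposes regular .weight / .bias parameters, which is why the test above can
# copy them straight from an fp32 layer.
layer = bnb.research.nn.LinearFP8Mixed(1024, 2048).cuda()
x = torch.randn(10, 1024, device="cuda")
y = layer(x)           # forward pass through the FP8 mixed-precision linear
y.mean().backward()    # backward fills layer.weight.grad and layer.bias.grad

The tolerances the test asserts (0.05 on outputs, 2e-5 on gradients) reflect that FP8 mixed precision is expected to track fp32 closely but not bit-exactly.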

tests/test_modules.py

Lines changed: 0 additions & 31 deletions
@@ -343,37 +343,6 @@ def test_kbit_backprop(device, module):
     assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
 
 
-@pytest.mark.deprecated
-def test_fp8linear():
-    b = 10
-    h = 1024
-    inp = torch.randn(b, h).cuda()
-    fp32 = torch.nn.Linear(h, h * 2).cuda()
-    fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
-    fp32b = torch.nn.Linear(h * 2, h).cuda()
-    fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
-
-    fp8.weight.data.copy_(fp32.weight.data)
-    fp8.bias.data.copy_(fp32.bias.data)
-    fp8b.weight.data.copy_(fp32b.weight.data)
-    fp8b.bias.data.copy_(fp32b.bias.data)
-
-    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
-    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
-
-    err = (a - b).abs().mean()
-
-    a.mean().backward()
-    b.mean().backward()
-
-    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
-    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
-
-    assert err < 0.05
-    assert graderr < 0.00002
-    assert bgraderr < 0.00002
-
-
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)

0 commit comments
