Skip to content

Commit d11b506

Browse files
tests: fix all_close to respect max 2 positional args (#1074)
1 parent 0bf7198 commit d11b506

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

tests/test_functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@
2626

2727

2828
def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
29-
idx = torch.isclose(a, b, rtol, atol)
29+
idx = torch.isclose(a, b, rtol=rtol, atol=atol)
3030
sumval = (idx == 0).sum().item()
3131
if sumval > count:
3232
if throw:
3333
print(f"Too many values not close: assert {sumval} < {count}")
34-
torch.testing.assert_close(a, b, rtol, atol)
34+
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
3535

3636
return sumval
3737

tests/test_modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ def get_args():
4242

4343

4444
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
45-
idx = torch.isclose(a, b, rtol, atol)
45+
idx = torch.isclose(a, b, rtol=rtol, atol=atol)
4646
sumval = (idx == 0).sum().item()
4747
if sumval > count:
4848
print(f"Too many values not close: assert {sumval} < {count}")
49-
torch.testing.assert_close(a, b, rtol, atol)
49+
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
5050

5151

5252
class LinearFunction(torch.autograd.Function):

tests/test_optim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
145145

146146
# since Lion can have pretty noisy updates where things lie at the boundary
147147
# allow up to 10 errors for Lion
148-
assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)
148+
assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)
149149

150150
if i % (k // 5) == 0 and i > 0:
151151
path = get_temp_dir()
@@ -157,7 +157,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
157157
rm_path(path)
158158
# since Lion can have pretty noisy updates where things lie at the boundary
159159
# allow up to 10 errors for Lion
160-
assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)
160+
assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)
161161
for name1, name2 in str2statenames[optim_name]:
162162
# since Lion can have pretty noisy updates where things lie at the boundary
163163
# allow up to 10 errors for Lion

0 commit comments

Comments
 (0)