Skip to content

Commit 357bb6a

Browse files
authored
test: adjust assertion for my_fn_calls based on optax version in test/test_optimizers.py (#2140)
* Revert "test: update test effected by the `optax==0.2.7` release (#2137)" This reverts commit 7d67ff5. * test: adjust assertion for `my_fn_calls` based on optax version in `test/test_optimizers.py` * fix: remove test skip * fix: remove optax version checks from the test
1 parent 7d67ff5 commit 357bb6a

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

test/test_optimizers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,8 @@ def my_fn(state, g):
123123
state = my_fn(state, jnp.ones(10) * 2.0)
124124
state = my_fn(state, jnp.ones(10) * 3.0)
125125

126-
assert my_fn_calls == 1
126+
if uses_value_arg:
127+
# Dtype is different on the first call vs the rest of the calls
128+
assert my_fn_calls in (1, 2)
129+
else:
130+
assert my_fn_calls == 1

0 commit comments

Comments
 (0)