Skip to content

Commit 2787055

Browse files
committed
test bug!
1 parent badf3e9 commit 2787055

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

tests/test_models.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_resnet():
6262

6363
out = net(t, x, q=q, a=a, key=key)
6464
assert out.shape == x.shape
65-
assert jnp.isfinite(out)
65+
assert jnp.all(jnp.isfinite(out))
6666

6767
net = ResidualNetwork(
6868
x.size,
@@ -182,11 +182,10 @@ def test_mixer():
182182
assert jnp.all(jnp.isfinite(out))
183183

184184

185-
def test_unet():
185+
def test_dit():
186186

187187
key = jr.key(0)
188188

189-
hidden_size = 32
190189
img_size = 32
191190
n_channels = 1
192191
embed_dim = 32

0 commit comments

Comments
 (0)