
Commit 9875a38

Add some tests on transform() with xCEBRA

1 parent 64d1db8

1 file changed: +18 -0

tests/test_integration_xcebra.py

Lines changed: 18 additions & 0 deletions
@@ -1,5 +1,6 @@
 import pickle
 
+import numpy as np
 import pytest
 import torch
 
@@ -150,3 +151,20 @@ def test_synthetic_data_training(synthetic_data, device):
     assert Z2_hat.shape == Z2.shape, f"Incorrect Z2 embedding dimension: {Z2_hat.shape}"
     assert not torch.isnan(Z1_hat).any(), "NaN values in Z1 embedding"
     assert not torch.isnan(Z2_hat).any(), "NaN values in Z2 embedding"
+
+    # Test the transform
+    solver.model.split_outputs = False
+    transform_embedding = solver.transform(data.neural.to(device))
+    assert transform_embedding.shape[
+        1] == n_latents, "Incorrect embedding dimension"
+    assert not torch.isnan(transform_embedding).any(), "NaN values in embedding"
+    assert np.allclose(embedding, transform_embedding, rtol=1e-02)
+
+    # Test the transform with batching
+    batched_embedding = solver.transform(data.neural.to(device), batch_size=512)
+    assert batched_embedding.shape[
+        1] == n_latents, "Incorrect embedding dimension"
+    assert not torch.isnan(batched_embedding).any(), "NaN values in embedding"
+    assert np.allclose(embedding, batched_embedding, rtol=1e-02)
+
+    assert np.allclose(transform_embedding, batched_embedding, rtol=1e-02)
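
The batching checks in this diff encode a general invariant: running transform() over the full dataset in one pass and in fixed-size batches should agree up to numerical tolerance. Below is a minimal, self-contained sketch of that invariant, using a stand-in linear model rather than the actual xCEBRA solver; the names toy_model and batched_transform, and the sizes (2000 samples, 30 features, 8 latents), are illustrative assumptions, not CEBRA's API.

import numpy as np
import torch

def batched_transform(model, data, batch_size):
    # Apply the model to fixed-size chunks and concatenate the results;
    # for a deterministic per-sample map, this should match a single
    # full-batch forward pass up to numerical tolerance.
    with torch.no_grad():
        chunks = [model(data[i:i + batch_size])
                  for i in range(0, len(data), batch_size)]
    return torch.cat(chunks, dim=0)

# Hypothetical stand-in for solver.model: any deterministic per-sample map.
toy_model = torch.nn.Linear(30, 8)

neural = torch.randn(2000, 30)      # (n_samples, n_features)
with torch.no_grad():
    full = toy_model(neural)        # one pass over all samples
batched = batched_transform(toy_model, neural, batch_size=512)

# Same checks as in the test above: shape, no NaNs, batched == full.
assert full.shape[1] == 8, "Incorrect embedding dimension"
assert not torch.isnan(batched).any(), "NaN values in embedding"
assert np.allclose(full.numpy(), batched.numpy(), rtol=1e-02)

Because the stand-in model processes each sample independently, the batched and full-pass outputs are identical; the rtol=1e-02 tolerance mirrors the test's allowance for small numerical differences.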
