Skip to content

Commit 9cfc70f

Browse files
timonmerkstes
authored andcommitted
add test for MixedDataLoader including additional keywords
1 parent 2804c3a commit 9cfc70f

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

tests/test_loader.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,37 @@ def test_continuous(conditional, device, benchmark):
186186
benchmark(load_speed)
187187

188188

189+
@parametrize_device
190+
@pytest.mark.parametrize(
191+
"conditional, positive_sampling, discrete_sampling_prior",
192+
[
193+
("time", "discrete_variable", "empirical"),
194+
("time", "conditional", "empirical"),
195+
("time", "discrete_variable", "uniform"),
196+
("time", "conditional", "uniform"),
197+
("time_delta", "discrete_variable", "empirical"),
198+
("time_delta", "conditional", "empirical"),
199+
("time_delta", "discrete_variable", "uniform"),
200+
("time_delta", "conditional", "uniform"),
201+
],
202+
)
203+
def test_mixed(
204+
conditional, positive_sampling, discrete_sampling_prior, device, benchmark
205+
):
206+
dataset = RandomDataset(N=100, d=5, device=device)
207+
loader = cebra.data.MixedDataLoader(
208+
dataset=dataset,
209+
num_steps=10,
210+
batch_size=8,
211+
conditional=conditional,
212+
positive_sampling=positive_sampling,
213+
discrete_sampling_prior=discrete_sampling_prior,
214+
)
215+
_assert_dataset_on_correct_device(loader, device)
216+
load_speed = LoadSpeed(loader)
217+
benchmark(load_speed)
218+
219+
189220
def _check_attributes(obj, is_list=False):
190221
if is_list:
191222
for obj_ in obj:

0 commit comments

Comments
 (0)