Skip to content

Commit 2fcfb7f

Browse files
committed
Add tests on the private functions in base solver
1 parent 22e3c47 commit 2fcfb7f

File tree

1 file changed

+150
-1
lines changed

1 file changed

+150
-1
lines changed

tests/test_solver_batched.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
device = "cpu"
3333

34-
NUM_STEPS = 10
34+
NUM_STEPS = 2
3535
BATCHES = [25_000, 50_000, 75_000]
3636
MODELS = ["offset1-model", "offset10-model", "offset40-model-4x-subsample"]
3737

@@ -341,3 +341,152 @@ def test_batched_transform_multi_session(data_name, model_name, padding,
341341

342342
assert embedding_batched.shape == embedding.shape
343343
assert np.allclose(embedding_batched, embedding, rtol=1e-4, atol=1e-4)
344+
345+
346+
@pytest.mark.parametrize(
347+
"batch_start_idx, batch_end_idx, offset, num_samples, expected_exception",
348+
[
349+
# Valid indices
350+
(0, 5, cebra.data.Offset(1, 1), 10, None),
351+
(2, 8, cebra.data.Offset(2, 2), 10, None),
352+
# Negative indices
353+
(-1, 5, cebra.data.Offset(1, 1), 10, ValueError),
354+
(0, -5, cebra.data.Offset(1, 1), 10, ValueError),
355+
# Start index greater than end index
356+
(5, 3, cebra.data.Offset(1, 1), 10, ValueError),
357+
# End index out of bounds
358+
(0, 11, cebra.data.Offset(1, 1), 10, ValueError),
359+
# Batch size smaller than offset
360+
(0, 2, cebra.data.Offset(3, 3), 10, ValueError),
361+
],
362+
)
363+
def test_check_indices(batch_start_idx, batch_end_idx, offset, num_samples,
364+
expected_exception):
365+
if expected_exception:
366+
with pytest.raises(expected_exception):
367+
cebra.solver.base._check_indices(batch_start_idx, batch_end_idx,
368+
offset, num_samples)
369+
else:
370+
cebra.solver.base._check_indices(batch_start_idx, batch_end_idx, offset,
371+
num_samples)
372+
373+
374+
@pytest.mark.parametrize(
375+
"batch_start_idx, batch_end_idx, num_samples, expected_exception",
376+
[
377+
# First batch
378+
(0, 6, 12, 8),
379+
# Last batch
380+
(6, 12, 12, 8),
381+
# Middle batch
382+
(3, 9, 12, 6),
383+
# Invalid start index
384+
(-1, 3, 4, ValueError),
385+
# Invalid end index
386+
(3, -10, 4, ValueError),
387+
# Start index greater than end index
388+
(5, 3, 4, ValueError),
389+
# End index out of bounds
390+
(0, 15, 12, ValueError),
391+
# Batch size smaller than batched_data
392+
(0, 2, 2, ValueError),
393+
# Batch size larger than batched_data
394+
(0, 12, 12, ValueError),
395+
],
396+
)
397+
def test_add_batched_zero_padding(batch_start_idx, batch_end_idx, num_samples,
398+
expected_exception):
399+
batched_data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
400+
[9.0, 10.0], [1.0, 2.0]])
401+
402+
model = create_model(model_name="offset5-model",
403+
input_dimension=batched_data.shape[1])
404+
offset = model.get_offset()
405+
406+
if expected_exception == ValueError:
407+
with pytest.raises(expected_exception):
408+
result = cebra.solver.base._add_batched_zero_padding(
409+
batched_data, offset, batch_start_idx, batch_end_idx,
410+
num_samples)
411+
else:
412+
result = cebra.solver.base._add_batched_zero_padding(
413+
batched_data, offset, batch_start_idx, batch_end_idx, num_samples)
414+
assert result.shape[0] == expected_exception
415+
416+
417+
@pytest.mark.parametrize(
418+
"pad_before_transform, expected_exception",
419+
[
420+
# Valid batched inputs
421+
(True, None),
422+
# No padding
423+
(False, None),
424+
],
425+
)
426+
def test_transform(pad_before_transform, expected_exception):
427+
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
428+
[9.0, 10.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0],
429+
[7.0, 8.0], [9.0, 10.0], [1.0, 2.0], [3.0, 4.0],
430+
[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
431+
model = create_model(model_name="offset5-model",
432+
input_dimension=inputs.shape[1])
433+
offset = model.get_offset()
434+
435+
result = cebra.solver.base._transform(
436+
model=model,
437+
inputs=inputs,
438+
pad_before_transform=pad_before_transform,
439+
offset=offset,
440+
)
441+
if pad_before_transform:
442+
assert result.shape[0] == inputs.shape[0]
443+
else:
444+
assert result.shape[0] == inputs.shape[0] - len(offset) + 1
445+
446+
447+
@pytest.mark.parametrize(
448+
"batch_size, pad_before_transform, expected_exception",
449+
[
450+
# Valid batched inputs
451+
(6, True, None),
452+
# Invalid batch size (too large)
453+
(12, True, ValueError),
454+
# Invalid batch size (too small)
455+
(2, True, ValueError),
456+
# Last batch size incomplete
457+
(5, True, None),
458+
# No padding
459+
(6, False, None),
460+
],
461+
)
462+
def test_batched_transform(batch_size, pad_before_transform,
463+
expected_exception):
464+
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
465+
[9.0, 10.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0],
466+
[7.0, 8.0], [9.0, 10.0], [1.0, 2.0], [3.0, 4.0],
467+
[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
468+
model = create_model(model_name="offset5-model",
469+
input_dimension=inputs.shape[1])
470+
offset = model.get_offset()
471+
472+
if expected_exception:
473+
with pytest.raises(expected_exception):
474+
cebra.solver.base._batched_transform(
475+
model=model,
476+
inputs=inputs,
477+
batch_size=batch_size,
478+
pad_before_transform=pad_before_transform,
479+
offset=offset,
480+
)
481+
else:
482+
result = cebra.solver.base._batched_transform(
483+
model=model,
484+
inputs=inputs,
485+
batch_size=batch_size,
486+
pad_before_transform=pad_before_transform,
487+
offset=offset,
488+
)
489+
if pad_before_transform:
490+
assert result.shape[0] == inputs.shape[0]
491+
else:
492+
assert result.shape[0] == inputs.shape[0] - len(offset) + 1

0 commit comments

Comments
 (0)