|
59 | 59 | cebra.data.ContinuousMultiSessionDataLoader, "offset1-model"), |
60 | 60 | ("demo-continuous-multisession", |
61 | 61 | cebra.data.ContinuousMultiSessionDataLoader, "offset10-model"), |
62 | | - ("demo-discrete-multisession", |
63 | | - cebra.data.DiscreteMultiSessionDataLoader, "offset1-model"), |
64 | | - ("demo-discrete-multisession", |
65 | | - cebra.data.DiscreteMultiSessionDataLoader, "offset10-model"), |
| 62 | + ("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader, |
| 63 | + "offset1-model"), |
| 64 | + ("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader, |
| 65 | + "offset10-model"), |
66 | 66 | ]: |
67 | 67 | multi_session_tests.append((*args, cebra.solver.MultiSessionSolver)) |
68 | 68 |
|
69 | 69 |
|
70 | | - |
71 | 70 | def _get_loader(data, loader_initfunc): |
72 | 71 | kwargs = dict(num_steps=5, batch_size=32) |
73 | 72 | loader = loader_initfunc(data, **kwargs) |
@@ -168,7 +167,7 @@ def test_single_session(data_name, loader_initfunc, model_architecture, |
168 | 167 |
|
169 | 168 | assert solver.num_sessions is None |
170 | 169 | assert solver.n_features == X.shape[1] |
171 | | - |
| 170 | + |
172 | 171 | embedding = solver.transform(X) |
173 | 172 | assert isinstance(embedding, torch.Tensor) |
174 | 173 | assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) |
@@ -527,158 +526,6 @@ def test_multi_session_2(data_name, loader_initfunc, solver_initfunc): |
527 | 526 |
|
528 | 527 | solver.fit(loader) |
529 | 528 |
|
530 | | - assert solver.num_sessions == 3 |
531 | | - assert solver.n_features == [X[i].shape[1] for i in range(len(X))] |
532 | | - |
533 | | - embedding = solver.transform(X[0], session_id=0) |
534 | | - assert isinstance(embedding, torch.Tensor) |
535 | | - assert embedding.shape == (X[0].shape[0], OUTPUT_DIMENSION) |
536 | | - embedding = solver.transform(X[1], session_id=1) |
537 | | - assert isinstance(embedding, torch.Tensor) |
538 | | - assert embedding.shape == (X[1].shape[0], OUTPUT_DIMENSION) |
539 | | - embedding = solver.transform(X[0], session_id=0, pad_before_transform=False) |
540 | | - assert isinstance(embedding, torch.Tensor) |
541 | | - assert embedding.shape == (X[0].shape[0] - |
542 | | - len(solver.model[0].get_offset()) + 1, |
543 | | - OUTPUT_DIMENSION) |
544 | | - |
545 | | - with pytest.raises(ValueError, match="torch.Tensor"): |
546 | | - embedding = solver.transform(X[0].numpy(), session_id=0) |
547 | | - |
548 | | - with pytest.raises(ValueError, match="shape"): |
549 | | - embedding = solver.transform(X[1], session_id=0) |
550 | | - with pytest.raises(ValueError, match="shape"): |
551 | | - embedding = solver.transform(X[0], session_id=1) |
552 | | - |
553 | | - with pytest.raises(RuntimeError, match="No.*session_id"): |
554 | | - embedding = solver.transform(X[0]) |
555 | | - with pytest.raises(ValueError, match="single.*session"): |
556 | | - embedding = solver.transform(X) |
557 | | - with pytest.raises(RuntimeError, match="Invalid.*session_id"): |
558 | | - embedding = solver.transform(X[0], session_id=5) |
559 | | - with pytest.raises(RuntimeError, match="Invalid.*session_id"): |
560 | | - embedding = solver.transform(X[0], session_id=-1) |
561 | | - |
562 | | - for param in solver.parameters(session_id=0): |
563 | | - assert isinstance(param, torch.Tensor) |
564 | | - |
565 | | - fitted_solver = copy.deepcopy(solver) |
566 | | - with tempfile.TemporaryDirectory() as temp_dir: |
567 | | - solver.save(temp_dir) |
568 | | - solver.load(temp_dir) |
569 | | - _assert_equal(fitted_solver, solver) |
570 | | - |
571 | | - |
572 | | -@pytest.mark.parametrize( |
573 | | - "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", |
574 | | - [ |
575 | | - # Test case 1: No padding |
576 | | - (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( |
577 | | - 0, 1), 0, 2, torch.tensor([[1, 2], [3, 4]])), # first batch |
578 | | - (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( |
579 | | - 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # last batch |
580 | | - (torch.tensor( |
581 | | - [[1, 2], [3, 4], [5, 6], [7, 8]]), False, cebra.data.Offset( |
582 | | - 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # middle batch |
583 | | -
|
584 | | - # Test case 2: First batch with padding |
585 | | - ( |
586 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
587 | | - True, |
588 | | - cebra.data.Offset(0, 1), |
589 | | - 0, |
590 | | - 2, |
591 | | - torch.tensor([[1, 2, 3], [4, 5, 6]]), |
592 | | - ), |
593 | | - ( |
594 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
595 | | - True, |
596 | | - cebra.data.Offset(1, 1), |
597 | | - 0, |
598 | | - 3, |
599 | | - torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
600 | | - ), |
601 | | -
|
602 | | - # Test case 3: Last batch with padding |
603 | | - ( |
604 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
605 | | - True, |
606 | | - cebra.data.Offset(0, 1), |
607 | | - 1, |
608 | | - 3, |
609 | | - torch.tensor([[4, 5, 6], [7, 8, 9]]), |
610 | | - ), |
611 | | - ( |
612 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], |
613 | | - [13, 14, 15]]), |
614 | | - True, |
615 | | - cebra.data.Offset(1, 2), |
616 | | - 1, |
617 | | - 3, |
618 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), |
619 | | - ), |
620 | | -
|
621 | | - # Test case 4: Middle batch with padding |
622 | | - ( |
623 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), |
624 | | - True, |
625 | | - cebra.data.Offset(0, 1), |
626 | | - 1, |
627 | | - 3, |
628 | | - torch.tensor([[4, 5, 6], [7, 8, 9]]), |
629 | | - ), |
630 | | - ( |
631 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), |
632 | | - True, |
633 | | - cebra.data.Offset(1, 1), |
634 | | - 1, |
635 | | - 3, |
636 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
637 | | - ), |
638 | | - ( |
639 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], |
640 | | - [13, 14, 15]]), |
641 | | - True, |
642 | | - cebra.data.Offset(0, 1), |
643 | | - 2, |
644 | | - 4, |
645 | | - torch.tensor([[7, 8, 9], [10, 11, 12]]), |
646 | | - ), |
647 | | - ( |
648 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), |
649 | | - True, |
650 | | - cebra.data.Offset(0, 1), |
651 | | - 0, |
652 | | - 3, |
653 | | - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
654 | | - ), |
655 | | -
|
656 | | - # Examples that throw an error: |
657 | | -
|
658 | | - # Padding without offset (should raise an error) |
659 | | - (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError), |
660 | | - # Negative start_batch_idx or end_batch_idx (should raise an error) |
661 | | - (torch.tensor([[1, 2]]), False, cebra.data.Offset( |
662 | | - 0, 1), -1, 2, ValueError), |
663 | | - # out of bound indices because offset is too large |
664 | | - (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset( |
665 | | - 5, 5), 1, 2, ValueError), |
666 | | - # Batch length is smaller than offset. |
667 | | - (torch.tensor([[1, 2], [3, 4]]), False, cebra.data.Offset( |
668 | | - 0, 1), 0, 1, ValueError), # first batch |
669 | | - ], |
670 | | -) |
671 | | -def test_get_batch(inputs, add_padding, offset, start_batch_idx, end_batch_idx, |
672 | | - expected_output): |
673 | | - if expected_output == ValueError: |
674 | | - with pytest.raises(ValueError): |
675 | | - cebra.solver.base._get_batch(inputs, offset, start_batch_idx, |
676 | | - end_batch_idx, add_padding) |
677 | | - else: |
678 | | - result = cebra.solver.base._get_batch(inputs, offset, start_batch_idx, |
679 | | - end_batch_idx, add_padding) |
680 | | - assert torch.equal(result, expected_output) |
681 | | - |
682 | 529 |
|
683 | 530 | def create_model(model_name, input_dimension): |
684 | 531 | return cebra.models.init(model_name, |
|
0 commit comments