|
59 | 59 | cebra.data.ContinuousMultiSessionDataLoader, "offset1-model"), |
60 | 60 | ("demo-continuous-multisession", |
61 | 61 | cebra.data.ContinuousMultiSessionDataLoader, "offset10-model"), |
| 62 | + ("demo-discrete-multisession", |
| 63 | + cebra.data.DiscreteMultiSessionDataLoader, "offset1-model"), |
| 64 | + ("demo-discrete-multisession", |
| 65 | + cebra.data.DiscreteMultiSessionDataLoader, "offset10-model"), |
62 | 66 | ]: |
63 | 67 | multi_session_tests.append((*args, cebra.solver.MultiSessionSolver)) |
64 | 68 |
|
65 | | -# multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver)) |
66 | | - |
67 | 69 |
|
68 | 70 | def _get_loader(data, loader_initfunc): |
69 | 71 | kwargs = dict(num_steps=5, batch_size=32) |
@@ -165,6 +167,28 @@ def test_single_session(data_name, loader_initfunc, model_architecture, |
165 | 167 |
|
166 | 168 | assert solver.num_sessions == None |
167 | 169 | assert solver.n_features == X.shape[1] |
| 170 | + |
| 171 | + embedding = solver.transform(X) |
| 172 | + assert isinstance(embedding, torch.Tensor) |
| 173 | + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) |
| 174 | + embedding = solver.transform(torch.Tensor(X)) |
| 175 | + assert isinstance(embedding, torch.Tensor) |
| 176 | + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) |
| 177 | + embedding = solver.transform(X, session_id=0) |
| 178 | + assert isinstance(embedding, torch.Tensor) |
| 179 | + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) |
| 180 | + embedding = solver.transform(X, pad_before_transform=False) |
| 181 | + assert isinstance(embedding, torch.Tensor) |
| 182 | + assert embedding.shape == (X.shape[0] - len(offset) + 1, OUTPUT_DIMENSION) |
| 183 | + |
| 184 | + with pytest.raises(ValueError, match="torch.Tensor"): |
| 185 | + solver.transform(X.numpy()) |
| 186 | + with pytest.raises(RuntimeError, match="Invalid.*session_id"): |
| 187 | + embedding = solver.transform(X, session_id=2) |
| 188 | + |
| 189 | + for param in solver.parameters(): |
| 190 | + assert isinstance(param, torch.Tensor) |
| 191 | + |
168 | 192 |
|
169 | 193 | embedding = solver.transform(X) |
170 | 194 | assert isinstance(embedding, torch.Tensor) |
@@ -320,6 +344,183 @@ def test_multi_session(data_name, loader_initfunc, model_architecture, |
320 | 344 | assert solver.num_sessions == 3 |
321 | 345 | assert solver.n_features == [X[i].shape[1] for i in range(len(X))] |
322 | 346 |
|
| 347 | + embedding = solver.transform(X[0], session_id=0) |
| 348 | + assert isinstance(embedding, torch.Tensor) |
| 349 | + assert embedding.shape == (X[0].shape[0], OUTPUT_DIMENSION) |
| 350 | + embedding = solver.transform(X[1], session_id=1) |
| 351 | + assert isinstance(embedding, torch.Tensor) |
| 352 | + assert embedding.shape == (X[1].shape[0], OUTPUT_DIMENSION) |
| 353 | + embedding = solver.transform(X[0], session_id=0, pad_before_transform=False) |
| 354 | + assert isinstance(embedding, torch.Tensor) |
| 355 | + assert embedding.shape == (X[0].shape[0] - |
| 356 | + len(solver.model[0].get_offset()) + 1, |
| 357 | + OUTPUT_DIMENSION) |
| 358 | + |
| 359 | + with pytest.raises(ValueError, match="torch.Tensor"): |
| 360 | + embedding = solver.transform(X[0].numpy(), session_id=0) |
| 361 | + |
| 362 | + with pytest.raises(ValueError, match="shape"): |
| 363 | + embedding = solver.transform(X[1], session_id=0) |
| 364 | + with pytest.raises(ValueError, match="shape"): |
| 365 | + embedding = solver.transform(X[0], session_id=1) |
| 366 | + |
| 367 | + with pytest.raises(RuntimeError, match="No.*session_id"): |
| 368 | + embedding = solver.transform(X[0]) |
| 369 | + with pytest.raises(RuntimeError, match="single.*session"): |
| 370 | + embedding = solver.transform(X) |
| 371 | + with pytest.raises(RuntimeError, match="Invalid.*session_id"): |
| 372 | + embedding = solver.transform(X[0], session_id=5) |
| 373 | + with pytest.raises(RuntimeError, match="Invalid.*session_id"): |
| 374 | + embedding = solver.transform(X[0], session_id=-1) |
| 375 | + |
| 376 | + for param in solver.parameters(session_id=0): |
| 377 | + assert isinstance(param, torch.Tensor) |
| 378 | + |
| 379 | + with pytest.raises(RuntimeError, match="No.*session_id"): |
| 380 | + for param in solver.parameters(): |
| 381 | + assert isinstance(param, torch.Tensor) |
| 382 | + |
| 383 | + |
| 384 | +@pytest.mark.parametrize( |
| 385 | + "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", |
| 386 | + [ |
| 387 | + # Test case 1: No padding |
| 388 | + (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( |
| 389 | + 0, 1), 0, 2, torch.tensor([[1, 2], [3, 4]])), # first batch |
| 390 | + (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( |
| 391 | + 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # last batch |
| 392 | + (torch.tensor( |
| 393 | + [[1, 2], [3, 4], [5, 6], [7, 8]]), False, cebra.data.Offset( |
| 394 | + 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # middle batch |
| 395 | +
|
| 396 | + # Test case 2: First batch with padding |
| 397 | + ( |
| 398 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
| 399 | + True, |
| 400 | + cebra.data.Offset(0, 1), |
| 401 | + 0, |
| 402 | + 2, |
| 403 | + torch.tensor([[1, 2, 3], [4, 5, 6]]), |
| 404 | + ), |
| 405 | + ( |
| 406 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
| 407 | + True, |
| 408 | + cebra.data.Offset(1, 1), |
| 409 | + 0, |
| 410 | + 3, |
| 411 | + torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
| 412 | + ), |
| 413 | +
|
| 414 | + # Test case 3: Last batch with padding |
| 415 | + ( |
| 416 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
| 417 | + True, |
| 418 | + cebra.data.Offset(0, 1), |
| 419 | + 1, |
| 420 | + 3, |
| 421 | + torch.tensor([[4, 5, 6], [7, 8, 9]]), |
| 422 | + ), |
| 423 | + ( |
| 424 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], |
| 425 | + [13, 14, 15]]), |
| 426 | + True, |
| 427 | + cebra.data.Offset(1, 2), |
| 428 | + 1, |
| 429 | + 3, |
| 430 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), |
| 431 | + ), |
| 432 | +
|
| 433 | + # Test case 4: Middle batch with padding |
| 434 | + ( |
| 435 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), |
| 436 | + True, |
| 437 | + cebra.data.Offset(0, 1), |
| 438 | + 1, |
| 439 | + 3, |
| 440 | + torch.tensor([[4, 5, 6], [7, 8, 9]]), |
| 441 | + ), |
| 442 | + ( |
| 443 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), |
| 444 | + True, |
| 445 | + cebra.data.Offset(1, 1), |
| 446 | + 1, |
| 447 | + 3, |
| 448 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
| 449 | + ), |
| 450 | + ( |
| 451 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], |
| 452 | + [13, 14, 15]]), |
| 453 | + True, |
| 454 | + cebra.data.Offset(0, 1), |
| 455 | + 2, |
| 456 | + 4, |
| 457 | + torch.tensor([[7, 8, 9], [10, 11, 12]]), |
| 458 | + ), |
| 459 | + ( |
| 460 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), |
| 461 | + True, |
| 462 | + cebra.data.Offset(0, 1), |
| 463 | + 0, |
| 464 | + 3, |
| 465 | + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), |
| 466 | + ), |
| 467 | +
|
| 468 | + # Examples that throw an error: |
| 469 | +
|
| 470 | + # Padding without offset (should raise an error) |
| 471 | + (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError), |
| 472 | + # Negative start_batch_idx or end_batch_idx (should raise an error) |
| 473 | + (torch.tensor([[1, 2]]), False, cebra.data.Offset( |
| 474 | + 0, 1), -1, 2, ValueError), |
| 475 | + # Out-of-bounds indices because the offset is too large (should raise an error) |
| 476 | + (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset( |
| 477 | + 5, 5), 1, 2, ValueError), |
| 478 | + # Batch length is smaller than offset. |
| 479 | + (torch.tensor([[1, 2], [3, 4]]), False, cebra.data.Offset( |
| 480 | + 0, 1), 0, 1, ValueError), # first batch |
| 481 | + ], |
| 482 | +) |
| 483 | +def test_get_batch(inputs, add_padding, offset, start_batch_idx, end_batch_idx, |
| 484 | + expected_output): |
| 485 | + if expected_output == ValueError: |
| 486 | + with pytest.raises(ValueError): |
| 487 | + cebra.solver.base._get_batch(inputs, offset, start_batch_idx, |
| 488 | + end_batch_idx, add_padding) |
| 489 | + else: |
| 490 | + result = cebra.solver.base._get_batch(inputs, offset, start_batch_idx, |
| 491 | + end_batch_idx, add_padding) |
| 492 | + assert torch.equal(result, expected_output) |
| 493 | + |
| 494 | + |
| 495 | +@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", |
| 496 | + multi_session_tests) |
| 497 | +def test_multi_session_2(data_name, loader_initfunc, solver_initfunc): |
| 498 | + loader = _get_loader(data_name, loader_initfunc) |
| 499 | + criterion = cebra.models.InfoNCE() |
| 500 | + model = nn.ModuleList( |
| 501 | + [_make_model(dataset) for dataset in loader.dataset.iter_sessions()]) |
| 502 | + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) |
| 503 | + |
| 504 | + solver = solver_initfunc(model=model, |
| 505 | + criterion=criterion, |
| 506 | + optimizer=optimizer, |
| 507 | + tqdm_on=True) |
| 508 | + |
| 509 | + batch = next(iter(loader)) |
| 510 | + for session_id, dataset in enumerate(loader.dataset.iter_sessions()): |
| 511 | + assert batch[session_id].reference.shape == (32, |
| 512 | + dataset.input_dimension, |
| 513 | + 10) |
| 514 | + assert batch[session_id].index is not None |
| 515 | + |
| 516 | + log = solver.step(batch) |
| 517 | + assert isinstance(log, dict) |
| 518 | + |
| 519 | + solver.fit(loader) |
| 520 | + |
| 521 | + assert solver.num_sessions == 3 |
| 522 | + assert solver.n_features == [X[i].shape[1] for i in range(len(X))] |
| 523 | + |
323 | 524 | embedding = solver.transform(X[0], session_id=0) |
324 | 525 | assert isinstance(embedding, torch.Tensor) |
325 | 526 | assert embedding.shape == (X[0].shape[0], OUTPUT_DIMENSION) |
@@ -504,8 +705,8 @@ def create_model(model_name, input_dimension): |
504 | 705 |
|
505 | 706 | @pytest.mark.parametrize( |
506 | 707 | "data_name, model_name ,session_id, loader_initfunc, solver_initfunc", |
507 | | - single_session_tests_select_model + single_session_hybrid_tests_select_model |
508 | | -) |
| 708 | + single_session_tests_select_model + |
| 709 | + single_session_hybrid_tests_select_model) |
509 | 710 | def test_select_model_single_session(data_name, model_name, session_id, |
510 | 711 | loader_initfunc, solver_initfunc): |
511 | 712 | dataset = cebra.datasets.init(data_name) |
@@ -576,7 +777,7 @@ def test_select_model_multi_session(data_name, model_name, session_id, |
576 | 777 | "offset40-model-4x-subsample", |
577 | 778 | "offset1-model", |
578 | 779 | "offset10-model", |
579 | | -] |
| 780 | +]  # NOTE(rodrigo): there is an issue with "offset4-model-2x-subsample" because it is not a convolutional model. |
580 | 781 | batch_size_inference = [40_000, 99_990, 99_999] |
581 | 782 |
|
582 | 783 | single_session_tests_transform = [] |
|
0 commit comments