Skip to content

Commit 165d641

Browse files
committed
Implement review comments
1 parent 8798aa0 commit 165d641

File tree

11 files changed

+430
-475
lines changed

11 files changed

+430
-475
lines changed

cebra/data/datasets.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import cebra.data.masking as cebra_data_masking
3333
import cebra.helper as cebra_helper
3434
import cebra.io as cebra_io
35+
import cebra.models
3536
from cebra.data.datatypes import Batch
3637
from cebra.data.datatypes import BatchIndex
3738
from cebra.data.datatypes import Offset
@@ -475,6 +476,18 @@ def _get_batches(self, index):
475476
) for session_id in range(self.num_sessions)
476477
]
477478

479+
def configure_for(self, model: "cebra.models.Model"):
480+
"""Configure the dataset offset for the provided model.
481+
482+
Call this function before indexing the dataset. This sets the
483+
:py:attr:`~.Dataset.offset` attribute of the dataset.
484+
485+
Args:
486+
model: The model to configure the dataset for.
487+
"""
488+
for i, session in enumerate(self.iter_sessions()):
489+
session.configure_for(model)
490+
478491
def load_batch(self, index: BatchIndex) -> Batch:
479492
"""Return the data at the specified index location.
480493

cebra/data/mask.py

Lines changed: 0 additions & 342 deletions
This file was deleted.

0 commit comments

Comments
 (0)