@@ -305,6 +305,21 @@ def _iter_property(self, attr):
305305
306306# TODO(stes): This should be a single session dataset?
307307class DatasetxCEBRA (cebra_io .HasDevice ):
308+ """Dataset class for xCEBRA models.
309+
310+ This class handles neural data and associated labels for xCEBRA models, providing
311+ functionality for data loading and batch preparation.
312+
313+ Attributes:
314+ neural: Neural data as a torch.Tensor or numpy array
315+ labels: Labels associated with the data
316+ offset: Offset for the dataset
317+
318+ Args:
319+ neural: Neural data as a torch.Tensor or numpy array
320+ device: Device to store the data on (default: "cpu")
321+ **labels: Additional keyword arguments for labels associated with the data
322+ """
308323
309324 def __init__ (
310325 self ,
@@ -315,12 +330,23 @@ def __init__(
315330 super ().__init__ (device )
316331 self .neural = neural
317332 self .labels = labels
333+ self .offset = Offset (0 , 1 )
318334
319335 @property
320336 def input_dimension (self ) -> int :
337+ """Get the input dimension of the neural data.
338+
339+ Returns:
340+ The number of features in the neural data
341+ """
321342 return self .neural .shape [1 ]
322343
323344 def __len__ (self ):
345+ """Get the length of the dataset.
346+
347+ Returns:
348+ Number of samples in the dataset
349+ """
324350 return len (self .neural )
325351
326352 def configure_for (self , model : "Model" ):
@@ -335,7 +361,8 @@ def configure_for(self, model: "Model"):
335361 self .offset = model .get_offset ()
336362
337363 def expand_index (self , index : torch .Tensor ) -> torch .Tensor :
338- """
364+ """Expand indices based on the configured offset.
365+
339366 Args:
340367 index: A one-dimensional tensor of type long containing indices
341368 to select from the dataset.
@@ -359,11 +386,28 @@ def expand_index(self, index: torch.Tensor) -> torch.Tensor:
359386 return index [:, None ] + offset [None , :]
360387
361388 def __getitem__ (self , index ):
389+ """Get item(s) from the dataset at the specified index.
390+
391+ Args:
392+ index: Index or indices to retrieve
393+
394+ Returns:
395+ The neural data at the specified indices, with dimensions transposed
396+ """
362397 index = self .expand_index (index )
363398 return self .neural [index ].transpose (2 , 1 )
364399
365400 def load_batch_supervised (self , index : Batch ,
366401 labels_supervised ) -> torch .tensor :
402+ """Load a batch for supervised learning.
403+
404+ Args:
405+ index: Batch indices for reference data
406+ labels_supervised: Labels to load for supervised learning
407+
408+ Returns:
409+ Batch containing reference data and corresponding labels
410+ """
367411 assert index .negative is None
368412 assert index .positive is None
369413 labels = [
@@ -377,6 +421,14 @@ def load_batch_supervised(self, index: Batch,
377421 )
378422
379423 def load_batch_contrastive (self , index : BatchIndex ) -> Batch :
424+ """Load a batch for contrastive learning.
425+
426+ Args:
427+ index: BatchIndex containing reference, positive and negative indices
428+
429+ Returns:
430+ Batch containing reference, positive and negative samples
431+ """
380432 assert isinstance (index .positive , list )
381433 return Batch (
382434 reference = self [index .reference ],
0 commit comments