Skip to content

Commit d81b93d

Browse files
committed
more docstrings to fix attr error
1 parent df4f661 commit d81b93d

File tree

1 file changed

+53
-1
lines changed

1 file changed

+53
-1
lines changed

cebra/data/datasets.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,21 @@ def _iter_property(self, attr):
305305

306306
# TODO(stes): This should be a single session dataset?
307307
class DatasetxCEBRA(cebra_io.HasDevice):
308+
"""Dataset class for xCEBRA models.
309+
310+
This class handles neural data and associated labels for xCEBRA models, providing
311+
functionality for data loading and batch preparation.
312+
313+
Attributes:
314+
neural: Neural data as a torch.Tensor or numpy array
315+
labels: Labels associated with the data
316+
offset: Offset for the dataset
317+
318+
Args:
319+
neural: Neural data as a torch.Tensor or numpy array
320+
device: Device to store the data on (default: "cpu")
321+
**labels: Additional keyword arguments for labels associated with the data
322+
"""
308323

309324
def __init__(
310325
self,
@@ -315,12 +330,23 @@ def __init__(
315330
super().__init__(device)
316331
self.neural = neural
317332
self.labels = labels
333+
self.offset = Offset(0, 1)
318334

319335
@property
320336
def input_dimension(self) -> int:
337+
"""Get the input dimension of the neural data.
338+
339+
Returns:
340+
The number of features in the neural data
341+
"""
321342
return self.neural.shape[1]
322343

323344
def __len__(self):
345+
"""Get the length of the dataset.
346+
347+
Returns:
348+
Number of samples in the dataset
349+
"""
324350
return len(self.neural)
325351

326352
def configure_for(self, model: "Model"):
@@ -335,7 +361,8 @@ def configure_for(self, model: "Model"):
335361
self.offset = model.get_offset()
336362

337363
def expand_index(self, index: torch.Tensor) -> torch.Tensor:
338-
"""
364+
"""Expand indices based on the configured offset.
365+
339366
Args:
340367
index: A one-dimensional tensor of type long containing indices
341368
to select from the dataset.
@@ -359,11 +386,28 @@ def expand_index(self, index: torch.Tensor) -> torch.Tensor:
359386
return index[:, None] + offset[None, :]
360387

361388
def __getitem__(self, index):
389+
"""Get item(s) from the dataset at the specified index.
390+
391+
Args:
392+
index: Index or indices to retrieve
393+
394+
Returns:
395+
The neural data at the specified indices, with dimensions transposed
396+
"""
362397
index = self.expand_index(index)
363398
return self.neural[index].transpose(2, 1)
364399

365400
def load_batch_supervised(self, index: Batch,
366401
labels_supervised) -> torch.tensor:
402+
"""Load a batch for supervised learning.
403+
404+
Args:
405+
index: Batch indices for reference data
406+
labels_supervised: Labels to load for supervised learning
407+
408+
Returns:
409+
Batch containing reference data and corresponding labels
410+
"""
367411
assert index.negative is None
368412
assert index.positive is None
369413
labels = [
@@ -377,6 +421,14 @@ def load_batch_supervised(self, index: Batch,
377421
)
378422

379423
def load_batch_contrastive(self, index: BatchIndex) -> Batch:
424+
"""Load a batch for contrastive learning.
425+
426+
Args:
427+
index: BatchIndex containing reference, positive and negative indices
428+
429+
Returns:
430+
Batch containing reference, positive and negative samples
431+
"""
380432
assert isinstance(index.positive, list)
381433
return Batch(
382434
reference=self[index.reference],

0 commit comments

Comments
 (0)