Skip to content

Commit 61cb9b7

Browse files
committed
fix docs build, missing refs
1 parent e20fda1 commit 61cb9b7

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

cebra/data/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def __getitem__(self, index):
398398
return self.neural[index].transpose(2, 1)
399399

400400
def load_batch_supervised(self, index: Batch,
401-
labels_supervised) -> torch.tensor:
401+
labels_supervised) -> torch.Tensor:
402402
"""Load a batch for supervised learning.
403403
404404
Args:

cebra/data/single_session.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,9 @@ class ContinuousDataLoader(cebra_data.Loader):
172172
* auxiliary variables, using the empirical distribution of how behavior various across
173173
``time_offset`` timesteps (``time_delta``). Sampling for this setting is implemented
174174
in :py:class:`cebra.distributions.continuous.TimedeltaDistribution`.
175-
* alternatively, the distribution can be selected to be a Gaussian or von Mises-Fisher distribution
175+
* alternatively, the distribution can be selected to be a Gaussian distribution
176176
parametrized by a fixed ``delta`` around the reference sample, using the implementation in
177-
:py:class:`cebra.distributions.continuous.DeltaNormalDistribution` and
178-
:py:class:`cebra.distributions.continuous.DeltaVMFDistribution`.
177+
:py:class:`cebra.distributions.continuous.DeltaNormalDistribution`.
179178
180179
Args:
181180
See dataclass fields.
@@ -228,11 +227,13 @@ def _init_distribution(self):
228227
self.dataset.continuous_index,
229228
self.delta,
230229
device=self.device)
231-
elif self.conditional == "delta_vmf":
232-
self.distribution = cebra.distributions.DeltaVMFDistribution(
233-
self.dataset.continuous_index,
234-
self.delta,
235-
device=self.device)
230+
# TODO(stes): Add this distribution from internal xCEBRA codebase at a later point
231+
# in time, currently not in use.
232+
#elif self.conditional == "delta_vmf":
233+
# self.distribution = cebra.distributions.DeltaVMFDistribution(
234+
# self.dataset.continuous_index,
235+
# self.delta,
236+
# device=self.device)
236237
else:
237238
raise ValueError(self.conditional)
238239

0 commit comments

Comments
 (0)