Skip to content

Commit 4424ba1

Browse files
committed
Fix docs errors
1 parent 4d5e9c3 commit 4424ba1

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

cebra/distributions/multisession.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ def num_sessions(self) -> int:
178178
"""The number of sessions in the index."""
179179
return len(self.lengths)
180180

181-
def mix(self, array: np.ndarray, idx: np.ndarray):
181+
def mix(self, array: "npt.NDArray[np.float64]",
182+
idx: "npt.NDArray[np.int64]"):
182183
"""Re-order array elements according to the given index mapping.
183184
184185
The given array should be of the shape ``(session, batch, ...)`` and the
@@ -439,7 +440,7 @@ class UnifiedSampler(MultisessionSampler):
439440
"""
440441

441442
def sample_all_uniform_prior(self,
442-
num_samples: int) -> npt.NDArray[np.int64]:
443+
num_samples: int) -> "npt.NDArray[np.int64]":
443444
"""Returns uniformly sampled index for all sessions of the dataset.
444445
445446
Args:
@@ -451,9 +452,10 @@ def sample_all_uniform_prior(self,
451452
"""
452453
return super().sample_prior(num_samples=num_samples)
453454

454-
def sample_prior(self,
455-
num_samples: int,
456-
session_id: Optional[int] = None) -> npt.NDArray[np.int64]:
455+
def sample_prior(
456+
self,
457+
num_samples: int,
458+
session_id: Optional[int] = None) -> "npt.NDArray[np.int64]":
457459
"""Return uniformly sampled indices for all sessions.
458460
459461
First, the reference indexes in a reference session are uniformly sampled.
@@ -506,7 +508,7 @@ def sample_all_sessions(self, ref_idx: torch.Tensor,
506508
session_id: int) -> torch.Tensor:
507509
"""Sample sessions based on a reference session.
508510
509-
Reference samples for the ``(session_id)``th session were first sampled uniformly, as in
511+
Reference samples for the ``session_id``th session were first sampled uniformly, as in
510512
the py:class:`~.MultisessionSampler`. Then, reference samples for the other sessions
511513
are sampled so that they are as close as the corresponding auxiliary variables in
512514
the reference session.
@@ -516,7 +518,7 @@ def sample_all_sessions(self, ref_idx: torch.Tensor,
516518
517519
Args:
518520
ref_idx: Uniformly sampled ``idx`` for the reference session, ``(num_samples, )``, values
519-
can be in ``[0, len(get_session[session_id])]``.
521+
can be in ``[0, len(session)]``.
520522
session_id: Session ID of the reference session, whose ``idx`` are present in ``ref_idx``.
521523
522524
Returns:
@@ -557,7 +559,7 @@ def sample_all_sessions(self, ref_idx: torch.Tensor,
557559
return all_idx
558560

559561
def sample_conditional(
560-
self, reference_idx: npt.NDArray[np.int64]) -> torch.Tensor:
562+
self, reference_idx: "npt.NDArray[np.int64]") -> torch.Tensor:
561563
"""Sample from the conditional distribution.
562564
563565
Contrary to the :py:class:`MultisessionSampler`, conditional distribution

docs/source/api/pytorch/helpers.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,11 @@ Data helpers
3535
.. automodule:: cebra.data.helper
3636
:members:
3737
:show-inheritance:
38+
39+
40+
Masking helpers
41+
----------------
42+
43+
.. automodule:: cebra.data.masking
44+
:members:
45+
:show-inheritance:

tests/test_data_masking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def test_apply_mask_with_invalid_input():
202202

203203
with pytest.raises(ValueError, match="Data must be a 3D tensor"):
204204
data = torch.ones(
205-
(10, 20)) # Invalid tensor shape (missing offset dimension)
205+
(10, 20, 30, 40)) # Invalid tensor shape (extra dimension)
206206
mixin.apply_mask(data)
207207

208208
with pytest.raises(ValueError, match="Data must be a float32 tensor"):

0 commit comments

Comments
 (0)