Skip to content

Commit a968768

Browse files
committed
Remove np.int typing error
1 parent 4424ba1 commit a968768

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

cebra/distributions/multisession.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ def num_sessions(self) -> int:
178178
"""The number of sessions in the index."""
179179
return len(self.lengths)
180180

181-
def mix(self, array: "npt.NDArray[np.float64]",
182-
idx: "npt.NDArray[np.int64]"):
181+
def mix(self, array: npt.NDArray, idx: npt.NDArray):
183182
"""Re-order array elements according to the given index mapping.
184183
185184
The given array should be of the shape ``(session, batch, ...)`` and the
@@ -439,8 +438,7 @@ class UnifiedSampler(MultisessionSampler):
439438
440439
"""
441440

442-
def sample_all_uniform_prior(self,
443-
num_samples: int) -> "npt.NDArray[np.int64]":
441+
def sample_all_uniform_prior(self, num_samples: int) -> npt.NDArray:
444442
"""Returns uniformly sampled index for all sessions of the dataset.
445443
446444
Args:
@@ -452,10 +450,9 @@ def sample_all_uniform_prior(self,
452450
"""
453451
return super().sample_prior(num_samples=num_samples)
454452

455-
def sample_prior(
456-
self,
457-
num_samples: int,
458-
session_id: Optional[int] = None) -> "npt.NDArray[np.int64]":
453+
def sample_prior(self,
454+
num_samples: int,
455+
session_id: Optional[int] = None) -> npt.NDArray:
459456
"""Return uniformly sampled indices for all sessions.
460457
461458
First, the reference indexes in a reference session are uniformly sampled.
@@ -558,8 +555,7 @@ def sample_all_sessions(self, ref_idx: torch.Tensor,
558555
all_idx[session_id] = ref_idx
559556
return all_idx
560557

561-
def sample_conditional(
562-
self, reference_idx: "npt.NDArray[np.int64]") -> torch.Tensor:
558+
def sample_conditional(self, reference_idx: npt.NDArray) -> torch.Tensor:
563559
"""Sample from the conditional distribution.
564560
565561
Contrary to the :py:class:`MultisessionSampler`, conditional distribution

0 commit comments

Comments
 (0)