@@ -178,8 +178,7 @@ def num_sessions(self) -> int:
178178 """The number of sessions in the index."""
179179 return len (self .lengths )
180180
181- def mix (self , array : "npt.NDArray[np.float64]" ,
182- idx : "npt.NDArray[np.int64]" ):
181+ def mix (self , array : npt .NDArray , idx : npt .NDArray ):
183182 """Re-order array elements according to the given index mapping.
184183
185184 The given array should be of the shape ``(session, batch, ...)`` and the
@@ -439,8 +438,7 @@ class UnifiedSampler(MultisessionSampler):
439438
440439 """
441440
442- def sample_all_uniform_prior (self ,
443- num_samples : int ) -> "npt.NDArray[np.int64]" :
441+ def sample_all_uniform_prior (self , num_samples : int ) -> npt .NDArray :
444442 """Returns uniformly sampled index for all sessions of the dataset.
445443
446444 Args:
@@ -452,10 +450,9 @@ def sample_all_uniform_prior(self,
452450 """
453451 return super ().sample_prior (num_samples = num_samples )
454452
455- def sample_prior (
456- self ,
457- num_samples : int ,
458- session_id : Optional [int ] = None ) -> "npt.NDArray[np.int64]" :
453+ def sample_prior (self ,
454+ num_samples : int ,
455+ session_id : Optional [int ] = None ) -> npt .NDArray :
459456 """Return uniformly sampled indices for all sessions.
460457
461458 First, the reference indexes in a reference session are uniformly sampled.
@@ -558,8 +555,7 @@ def sample_all_sessions(self, ref_idx: torch.Tensor,
558555 all_idx [session_id ] = ref_idx
559556 return all_idx
560557
561- def sample_conditional (
562- self , reference_idx : "npt.NDArray[np.int64]" ) -> torch .Tensor :
558+ def sample_conditional (self , reference_idx : npt .NDArray ) -> torch .Tensor :
563559 """Sample from the conditional distribution.
564560
565561 Contrary to the :py:class:`MultisessionSampler`, conditional distribution
0 commit comments