@@ -178,7 +178,8 @@ def num_sessions(self) -> int:
178178 """The number of sessions in the index."""
179179 return len (self .lengths )
180180
181- def mix (self , array : np .ndarray , idx : np .ndarray ):
181+ def mix (self , array : "npt.NDArray[np.float64]" ,
182+ idx : "npt.NDArray[np.int64]" ):
182183 """Re-order array elements according to the given index mapping.
183184
184185 The given array should be of the shape ``(session, batch, ...)`` and the
@@ -439,7 +440,7 @@ class UnifiedSampler(MultisessionSampler):
439440 """
440441
441442 def sample_all_uniform_prior (self ,
442- num_samples : int ) -> npt .NDArray [np .int64 ]:
443+ num_samples : int ) -> " npt.NDArray[np.int64]" :
443444 """Returns uniformly sampled index for all sessions of the dataset.
444445
445446 Args:
@@ -451,9 +452,10 @@ def sample_all_uniform_prior(self,
451452 """
452453 return super ().sample_prior (num_samples = num_samples )
453454
454- def sample_prior (self ,
455- num_samples : int ,
456- session_id : Optional [int ] = None ) -> npt .NDArray [np .int64 ]:
455+ def sample_prior (
456+ self ,
457+ num_samples : int ,
458+ session_id : Optional [int ] = None ) -> "npt.NDArray[np.int64]" :
457459 """Return uniformly sampled indices for all sessions.
458460
459461 First, the reference indexes in a reference session are uniformly sampled.
@@ -506,7 +508,7 @@ def sample_all_sessions(self, ref_idx: torch.Tensor,
506508 session_id : int ) -> torch .Tensor :
507509 """Sample sessions based on a reference session.
508510
509- Reference samples for the ``( session_id) ``th session were first sampled uniformly, as in
511+ Reference samples for the ``session_id``th session were first sampled uniformly, as in
510512 the py:class:`~.MultisessionSampler`. Then, reference samples for the other sessions
511513 are sampled so that they are as close as the corresponding auxiliary variables in
512514 the reference session.
@@ -516,7 +518,7 @@ def sample_all_sessions(self, ref_idx: torch.Tensor,
516518
517519 Args:
518520 ref_idx: Uniformly sampled ``idx`` for the reference session, ``(num_samples, )``, values
519- can be in ``[0, len(get_session[session_id] )]``.
521+ can be in ``[0, len(session )]``.
520522 session_id: Session ID of the reference session, whose ``idx`` are present in ``ref_idx``.
521523
522524 Returns:
@@ -557,7 +559,7 @@ def sample_all_sessions(self, ref_idx: torch.Tensor,
557559 return all_idx
558560
559561 def sample_conditional (
560- self , reference_idx : npt .NDArray [np .int64 ]) -> torch .Tensor :
562+ self , reference_idx : " npt.NDArray[np.int64]" ) -> torch .Tensor :
561563 """Sample from the conditional distribution.
562564
563565 Contrary to the :py:class:`MultisessionSampler`, conditional distribution
0 commit comments