We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 723bcfb commit dbabb6eCopy full SHA for dbabb6e
cebra/data/multi_session.py
@@ -160,7 +160,9 @@ def __post_init__(self):
160
161
# NOTE(stes): In the longer run, we need to unify the API here; the num_samples argument
162
# is not used in the multi-session case, which is different to the single session samples.
163
- def get_indices(self, num_samples) -> List[BatchIndex]:
+ def get_indices(self,
164
+ num_samples: int,
165
+ num_negatives: int = None) -> List[BatchIndex]:
166
ref_idx = self.sampler.sample_prior(self.batch_size)
167
neg_idx = self.sampler.sample_prior(self.num_negatives)
168
pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx)
0 commit comments