Skip to content

Commit dbabb6e

Browse files
committed
fix missing arg
1 parent 723bcfb commit dbabb6e

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

cebra/data/multi_session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ def __post_init__(self):
160160

161161
# NOTE(stes): In the longer run, we need to unify the API here; the num_samples argument
162162
# is not used in the multi-session case, which is different to the single session samples.
163-
def get_indices(self, num_samples) -> List[BatchIndex]:
163+
def get_indices(self,
164+
num_samples: int,
165+
num_negatives: int = None) -> List[BatchIndex]:
164166
ref_idx = self.sampler.sample_prior(self.batch_size)
165167
neg_idx = self.sampler.sample_prior(self.num_negatives)
166168
pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx)

0 commit comments

Comments
 (0)