2727
2828import abc
2929import warnings
30+ from typing import Iterator
3031
3132import literate_dataclasses as dataclasses
3233import torch
@@ -138,9 +139,7 @@ def _init_distribution(self):
138139 f"Invalid choice of prior distribution. Got '{ self .prior } ', but "
139140 f"only accept 'uniform' or 'empirical' as potential values." )
140141
141- def get_indices (self ,
142- num_samples : int ,
143- num_negatives : int = None ) -> BatchIndex :
142+ def get_indices (self ) -> BatchIndex :
144143 """Samples indices for reference, positive and negative examples.
145144
146145 The reference samples will be sampled from the empirical or uniform prior
@@ -161,13 +160,10 @@ def get_indices(self,
161160 Returns:
162161 Indices for reference, positive and negatives samples.
163162 """
164- if num_negatives is None :
165- num_negatives = num_samples
166-
167- reference_idx = self .distribution .sample_prior (num_samples +
168- num_negatives )
169- negative_idx = reference_idx [num_samples :]
170- reference_idx = reference_idx [:num_samples ]
163+ reference_idx = self .distribution .sample_prior (self .batch_size +
164+ self .num_negatives )
165+ negative_idx = reference_idx [self .batch_size :]
166+ reference_idx = reference_idx [:self .batch_size ]
171167 reference = self .index [reference_idx ]
172168 positive_idx = self .distribution .sample_conditional (reference )
173169 return BatchIndex (reference = reference_idx ,
@@ -253,9 +249,7 @@ def _init_distribution(self):
253249 else :
254250 raise ValueError (self .conditional )
255251
256- def get_indices (self ,
257- num_samples : int ,
258- num_negatives : int = None ) -> BatchIndex :
252+ def get_indices (self ) -> BatchIndex :
259253 """Samples indices for reference, positive and negative examples.
260254
261255 The reference and negative samples will be sampled uniformly from
@@ -271,13 +265,10 @@ def get_indices(self,
271265 Returns:
272266 Indices for reference, positive and negatives samples.
273267 """
274- if num_negatives is None :
275- num_negatives = num_samples
276-
277- reference_idx = self .distribution .sample_prior (num_samples +
278- num_negatives )
279- negative_idx = reference_idx [num_samples :]
280- reference_idx = reference_idx [:num_samples ]
268+ reference_idx = self .distribution .sample_prior (self .batch_size +
269+ self .num_negatives )
270+ negative_idx = reference_idx [self .batch_size :]
271+ reference_idx = reference_idx [:self .batch_size ]
281272 positive_idx = self .distribution .sample_conditional (reference_idx )
282273 return BatchIndex (reference = reference_idx ,
283274 positive = positive_idx ,
@@ -318,9 +309,7 @@ def __post_init__(self):
318309 continuous = self .cindex ,
319310 time_delta = self .time_offset )
320311
321- def get_indices (self ,
322- num_samples : int ,
323- num_negatives : int = None ) -> BatchIndex :
312+ def get_indices (self ) -> BatchIndex :
324313 """Samples indices for reference, positive and negative examples.
325314
326315 The reference and negative samples will be sampled uniformly from
@@ -344,13 +333,10 @@ def get_indices(self,
344333 class.
345334 - Sample the negatives with matching discrete variable
346335 """
347- if num_negatives is None :
348- num_negatives = num_samples
349-
350- reference_idx = self .distribution .sample_prior (num_samples +
351- num_negatives )
352- negative_idx = reference_idx [num_samples :]
353- reference_idx = reference_idx [:num_samples ]
336+ reference_idx = self .distribution .sample_prior (self .batch_size +
337+ self .num_negatives )
338+ negative_idx = reference_idx [self .batch_size :]
339+ reference_idx = reference_idx [:self .batch_size ]
354340 return BatchIndex (
355341 reference = reference_idx ,
356342 negative = negative_idx ,
@@ -443,9 +429,7 @@ def _init_time_distribution(self):
443429 else :
444430 raise ValueError
445431
446- def get_indices (self ,
447- num_samples : int ,
448- num_negatives : int = None ) -> BatchIndex :
432+ def get_indices (self ) -> BatchIndex :
449433 """Samples indices for reference, positive and negative examples.
450434
451435 The reference and negative samples will be sampled uniformly from
@@ -469,13 +453,10 @@ def get_indices(self,
469453 Add the ``empirical`` vs. ``discrete`` sampling modes to this
470454 class.
471455 """
472- if num_negatives is None :
473- num_negatives = num_samples
474-
475- reference_idx = self .time_distribution .sample_prior (num_samples +
476- num_negatives )
477- negative_idx = reference_idx [num_samples :]
478- reference_idx = reference_idx [:num_samples ]
456+ reference_idx = self .time_distribution .sample_prior (self .batch_size +
457+ self .num_negatives )
458+ negative_idx = reference_idx [self .batch_size :]
459+ reference_idx = reference_idx [:self .batch_size ]
479460 behavior_positive_idx = self .behavior_distribution .sample_conditional (
480461 reference_idx )
481462 time_positive_idx = self .time_distribution .sample_conditional (
@@ -493,13 +474,18 @@ class FullDataLoader(ContinuousDataLoader):
493474
494475 def __post_init__ (self ):
495476 super ().__post_init__ ()
496- self .batch_size = None
477+
478+ if self .batch_size is not None :
479+ raise ValueError ("Batch size cannot be set for FullDataLoader." )
480+ if self .num_negatives is not None :
481+ raise ValueError (
482+ "Number of negatives cannot be set for FullDataLoader." )
497483
498484 @property
499485 def offset (self ):
500486 return self .dataset .offset
501487
502- def get_indices (self , num_samples = None , num_negatives = None ) -> BatchIndex :
488+ def get_indices (self ) -> BatchIndex :
503489 """Samples indices for reference, positive and negative examples.
504490
505491 The reference indices are all available (valid, according to the
@@ -519,8 +505,6 @@ def get_indices(self, num_samples=None, num_negatives=None) -> BatchIndex:
519505 Add the ``empirical`` vs. ``discrete`` sampling modes to this
520506 class.
521507 """
522- assert num_samples is None
523- assert num_negatives is None
524508
525509 reference_idx = torch .arange (
526510 self .offset .left ,
@@ -534,7 +518,6 @@ def get_indices(self, num_samples=None, num_negatives=None) -> BatchIndex:
534518 positive = positive_idx ,
535519 negative = negative_idx )
536520
537- def __iter__ (self ):
521+ def __iter__ (self ) -> Iterator [ BatchIndex ] :
538522 for _ in range (len (self )):
539- index = self .get_indices (num_samples = self .batch_size )
540- yield index
523+ yield self .get_indices ()
0 commit comments