@@ -186,6 +186,37 @@ def test_continuous(conditional, device, benchmark):
186186 benchmark (load_speed )
187187
188188
189+ @parametrize_device
190+ @pytest .mark .parametrize (
191+ "conditional, positive_sampling, discrete_sampling_prior" ,
192+ [
193+ ("time" , "discrete_variable" , "empirical" ),
194+ ("time" , "conditional" , "empirical" ),
195+ ("time" , "discrete_variable" , "uniform" ),
196+ ("time" , "conditional" , "uniform" ),
197+ ("time_delta" , "discrete_variable" , "empirical" ),
198+ ("time_delta" , "conditional" , "empirical" ),
199+ ("time_delta" , "discrete_variable" , "uniform" ),
200+ ("time_delta" , "conditional" , "uniform" ),
201+ ],
202+ )
203+ def test_mixed (
204+ conditional , positive_sampling , discrete_sampling_prior , device , benchmark
205+ ):
206+ dataset = RandomDataset (N = 100 , d = 5 , device = device )
207+ loader = cebra .data .MixedDataLoader (
208+ dataset = dataset ,
209+ num_steps = 10 ,
210+ batch_size = 8 ,
211+ conditional = conditional ,
212+ positive_sampling = positive_sampling ,
213+ discrete_sampling_prior = discrete_sampling_prior ,
214+ )
215+ _assert_dataset_on_correct_device (loader , device )
216+ load_speed = LoadSpeed (loader )
217+ benchmark (load_speed )
218+
219+
189220def _check_attributes (obj , is_list = False ):
190221 if is_list :
191222 for obj_ in obj :
0 commit comments