@@ -68,9 +68,9 @@ def test_demo():
6868
6969@pytest .mark .requires_dataset
7070def test_hippocampus ():
71- from cebra .datasets import hippocampus
72-
7371 pytest .skip ("Outdated" )
72+
73+ from cebra .datasets import hippocampus # noqa: F401
7474 dataset = cebra .datasets .init ("rat-hippocampus-single" )
7575 loader = cebra .data .ContinuousDataLoader (
7676 dataset = dataset ,
@@ -99,7 +99,7 @@ def test_hippocampus():
9999
100100@pytest .mark .requires_dataset
101101def test_monkey ():
102- from cebra .datasets import monkey_reaching
102+ from cebra .datasets import monkey_reaching # noqa: F401
103103
104104 dataset = cebra .datasets .init (
105105 "area2-bump-pos-active-passive" ,
@@ -111,7 +111,7 @@ def test_monkey():
111111
112112@pytest .mark .requires_dataset
113113def test_allen ():
114- from cebra .datasets import allen
114+ from cebra .datasets import allen # noqa: F401
115115
116116 pytest .skip ("Test takes too long" )
117117
@@ -148,7 +148,7 @@ def test_allen():
148148 multisubject_options .extend (
149149 cebra .datasets .get_options (
150150 "rat-hippocampus-multisubjects-3fold-trial-split*" ))
151- except :
151+ except : # noqa: E722
152152 options = []
153153
154154
@@ -388,3 +388,106 @@ def test_download_file_wrong_content_disposition(filename, url,
388388 expected_checksum = expected_checksum ,
389389 location = temp_dir ,
390390 file_name = filename )
391+
392+
@pytest.mark.parametrize(
    "neural, continuous, discrete",
    [
        # All three modalities present. Data is seeded so that test inputs
        # are reproducible across runs (unseeded np.random.* at collection
        # time makes failures impossible to replay).
        (np.random.default_rng(0).standard_normal((100, 30)),
         np.random.default_rng(1).standard_normal((100, 2)),
         np.random.default_rng(2).integers(0, 5, (100,))),
        # Discrete labels only.
        (np.random.default_rng(3).standard_normal((50, 20)), None,
         np.random.default_rng(4).integers(0, 3, (50,))),
        # Continuous labels only.
        (np.random.default_rng(5).standard_normal((200, 40)),
         np.random.default_rng(6).standard_normal((200, 5)), None),
    ],
)
def test_tensor_dataset_initialization(neural, continuous, discrete):
    """TensorDataset must keep the shapes of all arrays passed to it.

    Args:
        neural: 2D array of neural activity, (time, features).
        continuous: optional 2D array of continuous auxiliary variables.
        discrete: optional 1D integer array of discrete labels.
    """
    dataset = cebra.data.datasets.TensorDataset(neural,
                                                continuous=continuous,
                                                discrete=discrete)
    assert dataset.neural.shape == neural.shape
    if continuous is not None:
        assert dataset.continuous.shape == continuous.shape
    if discrete is not None:
        assert dataset.discrete.shape == discrete.shape
408+
409+
def test_tensor_dataset_invalid_initialization():
    """Constructing a TensorDataset with no auxiliary variable must fail.

    At least one of ``continuous`` / ``discrete`` is required, so passing
    only neural data is expected to raise a ``ValueError``.
    """
    neural_only = np.random.randn(100, 30)
    with pytest.raises(ValueError):
        cebra.data.datasets.TensorDataset(neural_only)
414+
415+
@pytest.mark.parametrize(
    "neural, continuous, discrete",
    [
        # Seeded generators make the parametrized inputs reproducible;
        # unseeded np.random.* evaluated at collection time is not.
        (np.random.default_rng(10).standard_normal((100, 30)),
         np.random.default_rng(11).standard_normal((100, 2)),
         np.random.default_rng(12).integers(0, 5, (100,))),
        (np.random.default_rng(13).standard_normal((50, 20)), None,
         np.random.default_rng(14).integers(0, 3, (50,))),
        (np.random.default_rng(15).standard_normal((200, 40)),
         np.random.default_rng(16).standard_normal((200, 5)), None),
    ],
)
def test_tensor_dataset_length(neural, continuous, discrete):
    """len(TensorDataset) must equal the number of time steps in `neural`."""
    dataset = cebra.data.datasets.TensorDataset(neural,
                                                continuous=continuous,
                                                discrete=discrete)
    assert len(dataset) == len(neural)
427+
428+
@pytest.mark.parametrize(
    "neural, continuous, discrete",
    [
        # Seeded generators keep the parametrized inputs reproducible.
        (np.random.default_rng(20).standard_normal((100, 30)),
         np.random.default_rng(21).standard_normal((100, 2)),
         np.random.default_rng(22).integers(0, 5, (100,))),
        (np.random.default_rng(23).standard_normal((50, 20)), None,
         np.random.default_rng(24).integers(0, 3, (50,))),
        (np.random.default_rng(25).standard_normal((200, 40)),
         np.random.default_rng(26).standard_normal((200, 5)), None),
    ],
)
def test_tensor_dataset_getitem(neural, continuous, discrete):
    """Indexing with a 1D index tensor yields a batch of matching size.

    The returned batch's first axis must equal the number of indices and
    its second axis must equal the neural feature dimension.
    """
    dataset = cebra.data.datasets.TensorDataset(neural,
                                                continuous=continuous,
                                                discrete=discrete)
    # Seed the index draw as well, so a failing batch can be reproduced.
    generator = torch.Generator().manual_seed(27)
    index = torch.randint(0, len(dataset), (10,), generator=generator)
    batch = dataset[index]
    assert batch.shape[0] == len(index)
    assert batch.shape[1] == neural.shape[1]
443+
444+
def test_tensor_dataset_invalid_discrete_type():
    """Float-valued discrete labels must be rejected with a ``TypeError``."""
    neural_data = np.random.randn(100, 30)
    continuous_labels = np.random.randn(100, 2)
    # Invalid on purpose: discrete labels must be integers, not floats.
    float_labels = np.random.randn(100, 2)
    with pytest.raises(TypeError):
        cebra.data.datasets.TensorDataset(neural_data,
                                          continuous=continuous_labels,
                                          discrete=float_labels)
453+
454+
@pytest.mark.parametrize(
    "array, check_dtype, expected_dtype",
    [
        # Seeded inputs: unseeded np.random/torch.rand* evaluated at
        # collection time makes the parametrized cases non-reproducible.
        (np.random.default_rng(30).standard_normal((100, 30)), "float",
         torch.float32),
        (np.random.default_rng(31).integers(0, 5, (100, 30)), "int",
         torch.int64),
        (torch.randn(100, 30,
                     generator=torch.Generator().manual_seed(32)), "float",
         torch.float32),
        (torch.randint(0, 5, (100, 30),
                       generator=torch.Generator().manual_seed(33)), "int",
         torch.int64),
        # None must pass straight through.
        (None, None, None),
    ],
)
def test_to_tensor(array, check_dtype, expected_dtype):
    """_to_tensor converts arrays/tensors to torch and passes None through.

    Args:
        array: input to convert (numpy array, torch tensor, or None).
        check_dtype: dtype family the helper should enforce, or None.
        expected_dtype: torch dtype expected on the converted tensor.
    """
    dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2),
                                                continuous=np.random.randn(
                                                    10, 2))
    result = dataset._to_tensor(array, check_dtype=check_dtype)
    if array is None:
        assert result is None
    else:
        assert isinstance(result, torch.Tensor)
        assert result.dtype == expected_dtype
472+
473+
def test_to_tensor_invalid_dtype():
    """_to_tensor must raise TypeError when the array dtype contradicts
    the requested ``check_dtype`` family."""
    dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2),
                                                continuous=np.random.randn(
                                                    10, 2))
    # Float data declared as int.
    with pytest.raises(TypeError):
        dataset._to_tensor(np.random.randn(100, 30), check_dtype="int")
    # Integer data declared as float.
    with pytest.raises(TypeError):
        dataset._to_tensor(np.random.randint(0, 5, (100, 30)),
                           check_dtype="float")
484+
485+
def test_to_tensor_invalid_check_dtype():
    """An unrecognized ``check_dtype`` string must raise a ValueError."""
    dataset = cebra.data.datasets.TensorDataset(np.random.randn(10, 2),
                                                continuous=np.random.randn(
                                                    10, 2))
    sample = np.random.randn(100, 30)
    expected_message = "check_dtype must be 'int' or 'float', got"
    with pytest.raises(ValueError, match=expected_message):
        dataset._to_tensor(sample, check_dtype="invalid_dtype")
0 commit comments