1818
1919
2020class TestGetters (unittest .TestCase ):
21- def helper (self , datasets , src_class , target_class , sizes ):
21+ def helper (
22+ self ,
23+ datasets ,
24+ src_class ,
25+ target_class ,
26+ sizes ,
27+ target_with_labels = False ,
28+ supervised = False ,
29+ ):
2230 for k in ["src_train" , "src_val" ]:
2331 self .assertTrue (isinstance (datasets [k ].dataset .datasets [0 ], src_class ))
32+ self .assertTrue (isinstance (datasets [k ], SourceDataset ))
2433 self .assertTrue (len (datasets [k ]) == sizes [k ])
25- for k in ["target_train" , "target_val" ]:
34+
35+ target_splits = ["target_train" , "target_val" ]
36+ if target_with_labels :
37+ target_splits += ["target_train_with_labels" , "target_val_with_labels" ]
38+ target_sizes = {k for k in sizes .keys () if k .startswith ("target" )}
39+
40+ self .assertTrue (set (target_splits ) == target_sizes )
41+ for k in target_splits :
2642 self .assertTrue (isinstance (datasets [k ].dataset .datasets [0 ], target_class ))
43+ self .assertTrue (isinstance (datasets [k ], TargetDataset ))
2744 self .assertTrue (len (datasets [k ]) == sizes [k ])
45+ if supervised :
46+ # target_train and target_val will be supervised
47+ # and if target_with_labels is true, the with_labels will also be supervised
48+ self .assertTrue (datasets [k ].supervised )
49+ else :
50+ # otherwise, only the ones with labels are supervised
51+ self .assertTrue (datasets [k ].supervised == k .endswith ("with_labels" ))
2852
2953 @unittest .skipIf (not RUN_DATASET_TESTS , skip_reason )
3054 def test_empty_array (self ):
@@ -41,21 +65,30 @@ def test_empty_array(self):
4165 self .assertTrue (isinstance (datasets ["train" ], TargetDataset ))
4266
4367 @unittest .skipIf (not RUN_DATASET_TESTS , skip_reason )
44- def test_get_mnist_mnistm (self ):
45- datasets = get_mnist_mnistm (
46- ["mnist" ], ["mnistm" ], folder = DATASET_FOLDER , download = True
47- )
48- self .helper (
49- datasets ,
50- MNIST ,
51- MNISTM ,
52- {
53- "src_train" : 60000 ,
54- "src_val" : 10000 ,
55- "target_train" : 59001 ,
56- "target_val" : 9001 ,
57- },
58- )
68+ def test_get_mnist_mnistm_return_targets (self ):
69+ for target_with_labels in [False , True ]:
70+ for supervised in [False , True ]:
71+ datasets = get_mnist_mnistm (
72+ ["mnist" ],
73+ ["mnistm" ],
74+ folder = DATASET_FOLDER ,
75+ download = True ,
76+ return_target_with_labels = target_with_labels ,
77+ supervised = supervised ,
78+ )
79+ len_dict = {
80+ "src_train" : 60000 ,
81+ "src_val" : 10000 ,
82+ "target_train" : 59001 ,
83+ "target_val" : 9001 ,
84+ }
85+ if target_with_labels :
86+ len_dict ["target_train_with_labels" ] = len_dict ["target_train" ]
87+ len_dict ["target_val_with_labels" ] = len_dict ["target_val" ]
88+
89+ self .helper (
90+ datasets , MNIST , MNISTM , len_dict , target_with_labels , supervised
91+ )
5992
6093 @unittest .skipIf (not RUN_DATASET_TESTS , skip_reason )
6194 def test_officehome (self ):
0 commit comments