@@ -120,7 +120,9 @@ def test_kather_dataset(tmp_path: Path) -> None:
120120 assert len (dataset .inputs ) == len (dataset .labels )
121121
122122 # to actually get the image, we feed it to PatchDataset
123- actual_ds = PatchDataset (dataset .inputs , dataset .labels )
123+ actual_ds = PatchDataset (
124+ dataset .inputs , dataset .labels , patch_input_shape = (224 , 224 )
125+ )
124126 sample_patch = actual_ds [89 ]
125127 assert isinstance (sample_patch ["image" ], np .ndarray )
126128 assert sample_patch ["label" ] is not None
@@ -136,7 +138,9 @@ def test_patch_dataset_path_imgs(
136138 """Test for patch dataset with a list of file paths as input."""
137139 size = (224 , 224 , 3 )
138140
139- dataset = PatchDataset ([Path (sample_patch1 ), Path (sample_patch2 )])
141+ dataset = PatchDataset (
142+ [Path (sample_patch1 ), Path (sample_patch2 )], patch_input_shape = size [:- 1 ]
143+ )
140144
141145 for _ , sample_data in enumerate (dataset ):
142146 sampled_img_shape = sample_data ["image" ].shape
@@ -152,7 +156,7 @@ def test_patch_dataset_list_imgs(tmp_path: Path) -> None:
152156 size = (5 , 5 , 3 )
153157 img = RNG .integers (low = 0 , high = 255 , size = size )
154158 list_imgs = [img , img , img ]
155- dataset = PatchDataset (list_imgs )
159+ dataset = PatchDataset (list_imgs , patch_input_shape = size [: - 1 ] )
156160
157161 dataset .preproc_func = lambda x : x
158162
@@ -197,14 +201,14 @@ def test_patch_datasetarray_imgs() -> None:
197201 array_imgs = np .array (list_imgs )
198202
199203 # test different setter for label
200- dataset = PatchDataset (array_imgs , labels = labels )
204+ dataset = PatchDataset (array_imgs , labels = labels , patch_input_shape = ( 224 , 224 ) )
201205 an_item = dataset [2 ]
202206 assert an_item ["label" ] == 3
203- dataset = PatchDataset (array_imgs , labels = None )
207+ dataset = PatchDataset (array_imgs , labels = None , patch_input_shape = ( 224 , 224 ) )
204208 an_item = dataset [2 ]
205209 assert "label" not in an_item
206210
207- dataset = PatchDataset (array_imgs )
211+ dataset = PatchDataset (array_imgs , patch_input_shape = size [: - 1 ] )
208212 for _ , sample_data in enumerate (dataset ):
209213 sampled_img_shape = sample_data ["image" ].shape
210214 assert sampled_img_shape [0 ] == size [0 ]
0 commit comments