Skip to content

Commit 92bc813

Browse files
committed
🐛 Fix test_datset architecture
1 parent 94a747f commit 92bc813

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

tests/models/test_dataset.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ def test_kather_dataset(tmp_path: Path) -> None:
120120
assert len(dataset.inputs) == len(dataset.labels)
121121

122122
# to actually get the image, we feed it to PatchDataset
123-
actual_ds = PatchDataset(dataset.inputs, dataset.labels)
123+
actual_ds = PatchDataset(
124+
dataset.inputs, dataset.labels, patch_input_shape=(224, 224)
125+
)
124126
sample_patch = actual_ds[89]
125127
assert isinstance(sample_patch["image"], np.ndarray)
126128
assert sample_patch["label"] is not None
@@ -136,7 +138,9 @@ def test_patch_dataset_path_imgs(
136138
"""Test for patch dataset with a list of file paths as input."""
137139
size = (224, 224, 3)
138140

139-
dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)])
141+
dataset = PatchDataset(
142+
[Path(sample_patch1), Path(sample_patch2)], patch_input_shape=size[:-1]
143+
)
140144

141145
for _, sample_data in enumerate(dataset):
142146
sampled_img_shape = sample_data["image"].shape
@@ -152,7 +156,7 @@ def test_patch_dataset_list_imgs(tmp_path: Path) -> None:
152156
size = (5, 5, 3)
153157
img = RNG.integers(low=0, high=255, size=size)
154158
list_imgs = [img, img, img]
155-
dataset = PatchDataset(list_imgs)
159+
dataset = PatchDataset(list_imgs, patch_input_shape=size[:-1])
156160

157161
dataset.preproc_func = lambda x: x
158162

@@ -197,14 +201,14 @@ def test_patch_datasetarray_imgs() -> None:
197201
array_imgs = np.array(list_imgs)
198202

199203
# test different setter for label
200-
dataset = PatchDataset(array_imgs, labels=labels)
204+
dataset = PatchDataset(array_imgs, labels=labels, patch_input_shape=(224, 224))
201205
an_item = dataset[2]
202206
assert an_item["label"] == 3
203-
dataset = PatchDataset(array_imgs, labels=None)
207+
dataset = PatchDataset(array_imgs, labels=None, patch_input_shape=(224, 224))
204208
an_item = dataset[2]
205209
assert "label" not in an_item
206210

207-
dataset = PatchDataset(array_imgs)
211+
dataset = PatchDataset(array_imgs, patch_input_shape=size[:-1])
208212
for _, sample_data in enumerate(dataset):
209213
sampled_img_shape = sample_data["image"].shape
210214
assert sampled_img_shape[0] == size[0]

0 commit comments

Comments
 (0)