Skip to content

Commit 2b342f4

Browse files
committed
Merge remote-tracking branch 'origin/dev-define-semantic-segmentor' into dev-define-semantic-segmentor
2 parents facf461 + bddd956 commit 2b342f4

File tree

3 files changed

+226
-49
lines changed

3 files changed

+226
-49
lines changed

tests/models/test_arch_unet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_functional_unet(remote_sample: Callable) -> None:
5050
pretrained = torch.load(pretrained_weights, map_location="cpu")
5151
model.load_state_dict(pretrained)
5252
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
53-
_ = output[0]
53+
_ = output["probabilities"][0]
5454

5555
# run untrained network to test for architecture
5656
model = UNetModel(

tests/models/test_dataset.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ def test_kather_dataset(tmp_path: Path) -> None:
120120
assert len(dataset.inputs) == len(dataset.labels)
121121

122122
# to actually get the image, we feed it to PatchDataset
123-
actual_ds = PatchDataset(dataset.inputs, dataset.labels)
123+
actual_ds = PatchDataset(
124+
dataset.inputs, dataset.labels, patch_input_shape=(224, 224)
125+
)
124126
sample_patch = actual_ds[89]
125127
assert isinstance(sample_patch["image"], np.ndarray)
126128
assert sample_patch["label"] is not None
@@ -136,7 +138,9 @@ def test_patch_dataset_path_imgs(
136138
"""Test for patch dataset with a list of file paths as input."""
137139
size = (224, 224, 3)
138140

139-
dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)])
141+
dataset = PatchDataset(
142+
[Path(sample_patch1), Path(sample_patch2)], patch_input_shape=size[:-1]
143+
)
140144

141145
for _, sample_data in enumerate(dataset):
142146
sampled_img_shape = sample_data["image"].shape
@@ -152,7 +156,7 @@ def test_patch_dataset_list_imgs(tmp_path: Path) -> None:
152156
size = (5, 5, 3)
153157
img = RNG.integers(low=0, high=255, size=size)
154158
list_imgs = [img, img, img]
155-
dataset = PatchDataset(list_imgs)
159+
dataset = PatchDataset(list_imgs, patch_input_shape=size[:-1])
156160

157161
dataset.preproc_func = lambda x: x
158162

@@ -197,14 +201,14 @@ def test_patch_datasetarray_imgs() -> None:
197201
array_imgs = np.array(list_imgs)
198202

199203
# test different setter for label
200-
dataset = PatchDataset(array_imgs, labels=labels)
204+
dataset = PatchDataset(array_imgs, labels=labels, patch_input_shape=(224, 224))
201205
an_item = dataset[2]
202206
assert an_item["label"] == 3
203-
dataset = PatchDataset(array_imgs, labels=None)
207+
dataset = PatchDataset(array_imgs, labels=None, patch_input_shape=(224, 224))
204208
an_item = dataset[2]
205209
assert "label" not in an_item
206210

207-
dataset = PatchDataset(array_imgs)
211+
dataset = PatchDataset(array_imgs, patch_input_shape=size[:-1])
208212
for _, sample_data in enumerate(dataset):
209213
sampled_img_shape = sample_data["image"].shape
210214
assert sampled_img_shape[0] == size[0]

0 commit comments

Comments
 (0)