Skip to content

Commit 7f148f4

Browse files
committed
✅ Add test for dask length threshold
1 parent 0242512 commit 7f148f4

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

tests/engines/test_semantic_segmentor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,32 @@ def test_wsi_segmentor_zarr(
246246
assert "count" not in output_
247247
assert "Current Memory usage:" in caplog.text
248248

249+
segmentor = SemanticSegmentor(
250+
model="fcn-tissue_mask",
251+
batch_size=64,
252+
verbose=False,
253+
num_workers=1,
254+
)
255+
# Return Probabilities is False
256+
output = segmentor.run(
257+
images=[sample_svs],
258+
return_probabilities=True,
259+
return_labels=False,
260+
device=device,
261+
patch_mode=False,
262+
save_dir=tmp_path / "task_length_cache",
263+
batch_size=2,
264+
output_type="zarr",
265+
da_length_threshold=1,
266+
)
267+
268+
output_ = zarr.open(output[sample_svs], mode="r")
269+
assert 0.17 < np.mean(output_["predictions"][:]) < 0.19
270+
assert "probabilities" in output_
271+
assert "canvas" not in output_
272+
assert "count" not in output_
273+
assert "Canvas task graph length:" in caplog.text
274+
249275
# Return Probabilities is True
250276
# Using small image for faster run
251277
segmentor = SemanticSegmentor(

tiatoolbox/models/engine/engine_abc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
dict_to_store_patch_predictions,
5959
get_tqdm,
6060
)
61+
from tiatoolbox.wsicore.wsireader import is_zarr
6162

6263
from .io_config import ModelIOConfigABC
6364

@@ -644,14 +645,16 @@ def save_predictions(
644645
keys_to_compute = [k for k in processed_predictions if k not in self.drop_keys]
645646

646647
if output_type.lower() == "zarr":
648+
if is_zarr(save_path):
649+
zarr_group = zarr.open(save_path, mode="r")
650+
keys_to_compute = [k for k in keys_to_compute if k not in zarr_group]
647651
write_tasks = []
648652
for key in keys_to_compute:
649653
dask_array = processed_predictions[key]
650654
task = dask_array.to_zarr(
651655
url=save_path,
652656
component=key,
653657
compute=False,
654-
overwrite=True,
655658
)
656659
write_tasks.append(task)
657660
msg = f"Saving output to {save_path}."

0 commit comments

Comments
 (0)