Skip to content

Commit 5b0d6d5

Browse files
committed
🐛 Fix bug for non-overlapping regions and add tests
1 parent e0f9405 commit 5b0d6d5

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

tests/engines/test_semantic_segmentor.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44

55
import json
66
import sqlite3
7+
import tempfile
78
from pathlib import Path
89
from typing import TYPE_CHECKING, Callable
10+
from unittest import mock
911

12+
import dask.array as da
1013
import numpy as np
1114
import torch
1215
import zarr
@@ -15,7 +18,10 @@
1518
from tiatoolbox import cli
1619
from tiatoolbox.annotation import SQLiteStore
1720
from tiatoolbox.models.engine import semantic_segmentor
18-
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
21+
from tiatoolbox.models.engine.semantic_segmentor import (
22+
SemanticSegmentor,
23+
merge_vertical_chunkwise,
24+
)
1925
from tiatoolbox.utils import env_detection as toolbox_env
2026
from tiatoolbox.utils.misc import imread
2127

@@ -275,6 +281,45 @@ def test_empty_blocks() -> None:
275281
assert np.array_equal(count, np.zeros((2, 2, 1), dtype=np.uint8))
276282

277283

284+
def test_merge_vertical_chunkwise_memory_threshold_triggered() -> None:
285+
"""Test merge vertical chunkwise for memory threshold."""
286+
# Create dummy canvas and count arrays with 3 vertical chunks
287+
data = np.ones((30, 10), dtype=np.uint8)
288+
canvas = da.from_array(data, chunks=(10, 10))
289+
count = da.from_array(data, chunks=(10, 10))
290+
291+
# Output locations to simulate overlaps
292+
output_locs_y_ = np.array([[0, 10], [10, 20], [20, 30]])
293+
294+
# Temporary Zarr group
295+
with tempfile.TemporaryDirectory() as tmpdir:
296+
save_path = Path(tmpdir)
297+
298+
# Mock psutil to simulate low memory
299+
with mock.patch(
300+
"tiatoolbox.models.engine.semantic_segmentor.psutil.virtual_memory"
301+
) as mock_vm:
302+
mock_vm.return_value.free = 1 # Very low free memory
303+
304+
result = merge_vertical_chunkwise(
305+
canvas=canvas,
306+
count=count,
307+
output_locs_y_=output_locs_y_,
308+
zarr_group=None,
309+
save_path=save_path,
310+
memory_threshold=0.01, # Very low threshold to trigger the condition
311+
)
312+
313+
# Assertions
314+
assert isinstance(result, da.Array)
315+
assert hasattr(result, "name")
316+
assert result.name.startswith("from-zarr")
317+
assert np.all(result.compute() == data)
318+
319+
zarr_group = zarr.open(tmpdir, mode="r")
320+
assert np.all(zarr_group["probabilities"][:] == data)
321+
322+
278323
def test_wsi_segmentor_zarr(
279324
remote_sample: Callable,
280325
sample_svs: Path,

tiatoolbox/models/engine/semantic_segmentor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@ def merge_vertical_chunkwise(
10021002
next_count = count.blocks[1, 0].compute() if num_chunks > 1 else None
10031003

10041004
for i, overlap in enumerate(tqdm_loop):
1005-
if next_chunk is not None:
1005+
if next_chunk is not None and overlap > 0:
10061006
curr_chunk[-overlap:] += next_chunk[:overlap]
10071007
curr_count[-overlap:] += next_count[:overlap]
10081008

0 commit comments

Comments
 (0)