
Commit 10d09bd: update post processing and saving
1 parent 656dbb9

File tree

5 files changed: +333 −539 lines changed

tests/engines/test_nucleus_detection_engine.py
Lines changed: 25 additions & 78 deletions
@@ -12,7 +12,7 @@
 from tiatoolbox.annotation.storage import SQLiteStore
 from tiatoolbox.models.engine.nucleus_detector import (
     NucleusDetector,
-    _flatten_predictions_to_dask,
+    # _flatten_predictions_to_dask,
 )
 from tiatoolbox.utils import env_detection as toolbox_env
 from tiatoolbox.utils.misc import imwrite
@@ -54,7 +54,7 @@ def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -
     store = SQLiteStore.open(save_dir / "wsi4_512_512.db")
     assert 255 <= len(store.values()) <= 265
     annotation = next(iter(store.values()))
-    assert annotation.properties["type"] == "test_nucleus"
+    assert annotation.properties["class"] == "test_nucleus"
     store.close()

     nucleus_detector.drop_keys = ["probs"]
@@ -68,17 +68,18 @@ def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -
         overwrite=True,
         batch_size=8,
     )
+    print("Result path:", result_path)

     zarr_path = result_path[mini_wsi_svs]
     zarr_group = zarr.open(zarr_path, mode="r")
     xs = zarr_group["x"][:]
     ys = zarr_group["y"][:]
-    types = zarr_group["types"][:]
+    classes = zarr_group["classes"][:]
     probs = zarr_group.get("probs", None)
     assert probs is None
     assert 255 <= len(xs) <= 265
     assert 255 <= len(ys) <= 265
-    assert 255 <= len(types) <= 265
+    assert 255 <= len(classes) <= 265

     _rm_dir(save_dir)
     pathlib.Path.unlink(mini_wsi_svs)
@@ -174,17 +175,16 @@ def test_nucleus_detector_patches_dict_output(
         save_dir=None,
         class_dict=None,
     )
-    output_dict = output_dict["predictions"]
     assert len(output_dict["x"]) == 2
     assert len(output_dict["y"]) == 2
-    assert len(output_dict["types"]) == 2
+    assert len(output_dict["classes"]) == 2
     assert len(output_dict["probs"]) == 2
     assert len(output_dict["x"][0]) == 1
     assert len(output_dict["x"][1]) == 0
     assert len(output_dict["y"][0]) == 1
     assert len(output_dict["y"][1]) == 0
-    assert len(output_dict["types"][0]) == 1
-    assert len(output_dict["types"][1]) == 0
+    assert len(output_dict["classes"][0]) == 1
+    assert len(output_dict["classes"][1]) == 0
     assert len(output_dict["probs"][0]) == 1
     assert len(output_dict["probs"][1]) == 0
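For context, the asserts above encode the per-patch dict layout: each of "x", "y", "classes" and "probs" holds one array per input patch. Below is a minimal sketch of consuming that layout; only the key names come from the test, the values and variable names are made up for illustration.

import numpy as np

# Illustrative per-patch output: two patches, one detection in the first, none in the second.
output_dict = {
    "x": [np.array([12]), np.array([])],
    "y": [np.array([30]), np.array([])],
    "classes": [np.array([0]), np.array([])],
    "probs": [np.array([0.9]), np.array([])],
}

for patch_idx, (xs, ys, cls, probs) in enumerate(
    zip(output_dict["x"], output_dict["y"], output_dict["classes"], output_dict["probs"])
):
    for x, y, c, p in zip(xs, ys, cls, probs):
        print(f"patch {patch_idx}: nucleus at ({x}, {y}), class {c}, prob {p}")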

@@ -220,40 +220,16 @@ def test_nucleus_detector_patches_zarr_output(
         overwrite=True,
     )

-    zarr_group = zarr.open(output_path, mode="r")
-    output_dict = {
-        "x": zarr_group["x"][:],
-        "y": zarr_group["y"][:],
-        "types": zarr_group["types"][:],
-        "probs": zarr_group["probs"][:],
-        "patch_offsets": zarr_group["patch_offsets"][:],
-    }
-
-    assert len(output_dict["x"]) == 1
-    assert len(output_dict["y"]) == 1
-    assert len(output_dict["types"]) == 1
-    assert len(output_dict["probs"]) == 1
-    assert len(output_dict["patch_offsets"]) == 3
-
-    patch_1_start, patch_1_end = (
-        output_dict["patch_offsets"][0],
-        output_dict["patch_offsets"][1],
-    )
-    patch_2_start, patch_2_end = (
-        output_dict["patch_offsets"][1],
-        output_dict["patch_offsets"][2],
-    )
-    assert len(output_dict["x"][patch_1_start:patch_1_end]) == 1
-    assert len(output_dict["x"][patch_2_start:patch_2_end]) == 0
-
-    assert len(output_dict["y"][patch_1_start:patch_1_end]) == 1
-    assert len(output_dict["y"][patch_2_start:patch_2_end]) == 0
+    output_zarr = zarr.open(output_path, mode="r")

-    assert len(output_dict["types"][patch_1_start:patch_1_end]) == 1
-    assert len(output_dict["types"][patch_2_start:patch_2_end]) == 0
-
-    assert len(output_dict["probs"][patch_1_start:patch_1_end]) == 1
-    assert len(output_dict["probs"][patch_2_start:patch_2_end]) == 0
+    assert output_zarr["x"][0].size == 1
+    assert output_zarr["x"][1].size == 0
+    assert output_zarr["y"][0].size == 1
+    assert output_zarr["y"][1].size == 0
+    assert output_zarr["classes"][0].size == 1
+    assert output_zarr["classes"][1].size == 0
+    assert output_zarr["probs"][0].size == 1
+    assert output_zarr["probs"][1].size == 0

     _rm_dir(save_dir)
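The rewritten asserts read one small array per patch, stored as a separate component under each key (the engine_abc.py change below writes them as f"{key}/{i}"). A standalone sketch of that layout using plain zarr, assuming the zarr 2.x API; the store path and values are invented for illustration:

import numpy as np
import zarr

# Toy group mirroring the per-patch layout: one array per patch under "<key>/<patch index>".
group = zarr.open("toy_patch_output.zarr", mode="w")
group.create_dataset("x/0", data=np.array([12], dtype=np.uint32))
group.create_dataset("x/1", data=np.zeros(0, dtype=np.uint32))

read_back = zarr.open("toy_patch_output.zarr", mode="r")
assert read_back["x/0"].size == 1  # patch 0: one detection
assert read_back["x/1"].size == 0  # patch 1: empty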

@@ -269,12 +245,12 @@ def test_centroid_maps_to_detection_arrays() -> None:

     xs = detections["x"]
     ys = detections["y"]
-    types = detections["types"]
+    classes = detections["classes"]
     probs = detections["probs"]

     np.testing.assert_array_equal(xs, np.array([1, 3], dtype=np.uint32))
     np.testing.assert_array_equal(ys, np.array([1, 2], dtype=np.uint32))
-    np.testing.assert_array_equal(types, np.array([0, 1], dtype=np.uint32))
+    np.testing.assert_array_equal(classes, np.array([0, 1], dtype=np.uint32))
     np.testing.assert_array_equal(probs, np.array([1.0, 0.5], dtype=np.float32))
@@ -283,65 +259,36 @@ def test_write_detection_arrays_to_store() -> None:
     detection_arrays = {
         "x": np.array([1, 3], dtype=np.uint32),
         "y": np.array([1, 2], dtype=np.uint32),
-        "types": np.array([0, 1], dtype=np.uint32),
+        "classes": np.array([0, 1], dtype=np.uint32),
         "probs": np.array([1.0, 0.5], dtype=np.float32),
     }

-    store = NucleusDetector.write_detection_arrays_to_store(detection_arrays)
+    store = NucleusDetector.save_detection_arrays_to_store(detection_arrays)
     assert len(store.values()) == 2

     detection_arrays = {
         "x": np.array([1], dtype=np.uint32),
         "y": np.array([1, 2], dtype=np.uint32),
-        "types": np.array([0], dtype=np.uint32),
+        "classes": np.array([0], dtype=np.uint32),
         "probs": np.array([1.0, 0.5], dtype=np.float32),
     }
     with pytest.raises(
         ValueError,
         match=r"Detection record lengths are misaligned.",
     ):
-        _ = NucleusDetector.write_detection_arrays_to_store(detection_arrays)
+        _ = NucleusDetector.save_detection_arrays_to_store(detection_arrays)


 def test_write_detection_records_to_store_no_class_dict() -> None:
     """Test writing detection records to annotation store."""
     detection_records = (np.array([1]), np.array([2]), np.array([0]), np.array([1.0]))

     dummy_store = SQLiteStore()
-    total = NucleusDetector._write_detection_records_to_store(
+    total = NucleusDetector._write_detection_arrays_to_store(
         detection_records, store=dummy_store, scale_factor=(1.0, 1.0), class_dict=None
     )
     assert len(dummy_store.values()) == 1
     assert total == 1
     annotation = next(iter(dummy_store.values()))
-    assert annotation.properties["type"] == 0
+    assert annotation.properties["class"] == 0
     dummy_store.close()
-
-
-def test_flatten_predictions_to_dask() -> None:
-    """Test flattening ragged predictions to Dask array."""
-    ragged_obj_array = np.empty(3, dtype=object)
-    ragged_obj_array[0] = np.array([1.0, 0.0], dtype=np.float32)
-    ragged_obj_array[1] = np.array([0.5, 0.5], dtype=np.float32)
-    ragged_obj_array[2] = np.array([0.2, 0.8, 0.8, 0.2], dtype=np.float32)
-
-    ragged_da_array = da.from_array(ragged_obj_array, chunks=(len(ragged_obj_array),))
-
-    flat_dask_array = _flatten_predictions_to_dask(ragged_da_array)
-    expected_array = np.array(
-        [
-            1.0,
-            0.0,
-            0.5,
-            0.5,
-            0.2,
-            0.8,
-            0.8,
-            0.2,
-        ],
-        dtype=np.float32,
-    )
-    np.testing.assert_array_equal(flat_dask_array.compute(), expected_array)
-
-    flat_dask_array = _flatten_predictions_to_dask(ragged_obj_array)
-    np.testing.assert_array_equal(flat_dask_array.compute(), expected_array)
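The store-facing helpers are renamed in this hunk (write_detection_arrays_to_store becomes save_detection_arrays_to_store, and annotations carry a "class" property instead of "type"). Conceptually the conversion is one point annotation per detection; the following is a rough sketch of that idea, not the toolbox's implementation, and the "prob" property name is only illustrative:

import numpy as np
from shapely.geometry import Point

from tiatoolbox.annotation.storage import Annotation, SQLiteStore

detection_arrays = {
    "x": np.array([1, 3], dtype=np.uint32),
    "y": np.array([1, 2], dtype=np.uint32),
    "classes": np.array([0, 1], dtype=np.uint32),
    "probs": np.array([1.0, 0.5], dtype=np.float32),
}

store = SQLiteStore()  # in-memory annotation store
for x, y, cls, prob in zip(*detection_arrays.values()):
    # One point annotation per detection, keyed by the renamed "class" property.
    store.append(Annotation(Point(x, y), properties={"class": int(cls), "prob": float(prob)}))

assert len(store.values()) == 2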

tests/models/test_arch_mapde.py
Lines changed: 6 additions & 3 deletions

@@ -8,7 +8,6 @@

 from tiatoolbox.models import MapDe
 from tiatoolbox.models.architecture import fetch_pretrained_weights
-from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
 from tiatoolbox.utils import env_detection as toolbox_env
 from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
@@ -49,7 +48,11 @@ def test_functionality(remote_sample: Callable) -> None:
     batch = torch.from_numpy(patch)[None]
     output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     output = model.postproc(output[0])
-    xs, ys, _, _ = NucleusDetector._extract_detection_arrays_from_block(output, None)
+    (
+        ys,
+        xs,
+        _,
+    ) = np.nonzero(output)

     np.testing.assert_array_equal(xs[0:2], np.array([242, 192]))
     np.testing.assert_array_equal(ys[0:2], np.array([10, 13]))
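The replacement relies on numpy's axis order: for a post-processed map of shape (H, W, C), np.nonzero returns row indices (y) first, then column indices (x), then channel indices, which is why the unpack order is ys, xs, _. A small self-contained check with made-up values:

import numpy as np

# 5x5 single-channel map with one positive pixel at row 2, column 4.
detection_map = np.zeros((5, 5, 1), dtype=np.float32)
detection_map[2, 4, 0] = 1.0

ys, xs, _ = np.nonzero(detection_map)
assert ys.tolist() == [2]  # rows -> y coordinates
assert xs.tolist() == [4]  # columns -> x coordinates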
@@ -73,7 +76,7 @@ def test_functionality(remote_sample: Callable) -> None:
         }
     }
     output = model.postproc(output[0], block_info=block_info)
-    xs, ys, _, _ = NucleusDetector._extract_detection_arrays_from_block(output, None)
+    ys, xs, _ = np.nonzero(output)
     np.testing.assert_array_equal(xs, np.array([]))
     np.testing.assert_array_equal(ys, np.array([]))

tests/models/test_arch_sccnn.py
Lines changed: 3 additions & 4 deletions

@@ -7,7 +7,6 @@

 from tiatoolbox.models import SCCNN
 from tiatoolbox.models.architecture import fetch_pretrained_weights
-from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
 from tiatoolbox.utils import env_detection
 from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
@@ -49,7 +48,7 @@ def test_functionality(remote_sample: Callable) -> None:
         device=select_device(on_gpu=env_detection.has_gpu()),
     )
     output = model.postproc(output[0])
-    xs, ys, _, _ = NucleusDetector._extract_detection_arrays_from_block(output, None)
+    ys, xs, _ = np.nonzero(output)

     np.testing.assert_array_equal(xs, np.array([8]))
     np.testing.assert_array_equal(ys, np.array([7]))
@@ -66,7 +65,7 @@ def test_functionality(remote_sample: Callable) -> None:
         }
     }
     output = model.postproc(output[0], block_info=block_info)
-    xs, ys, _, _ = NucleusDetector._extract_detection_arrays_from_block(output, None)
+    ys, xs, _ = np.nonzero(output)
     np.testing.assert_array_equal(xs, np.array([7]))
     np.testing.assert_array_equal(ys, np.array([8]))

@@ -85,6 +84,6 @@ def test_functionality(remote_sample: Callable) -> None:
         }
     }
     output = model.postproc(output[0], block_info=block_info)
-    xs, ys, _, _ = NucleusDetector._extract_detection_arrays_from_block(output, None)
+    ys, xs, _, _ = np.nonzero(output)
     np.testing.assert_array_equal(xs, np.array([]))
     np.testing.assert_array_equal(ys, np.array([]))

tiatoolbox/models/engine/engine_abc.py
Lines changed: 24 additions & 7 deletions

@@ -46,6 +46,7 @@
 import zarr
 from dask import compute
 from dask.diagnostics.progress import ProgressBar
+from numcodecs import Pickle
 from torch import nn
 from typing_extensions import Unpack
@@ -704,13 +705,29 @@ def save_predictions(
         keys_to_compute = [k for k in keys_to_compute if k not in zarr_group]
         write_tasks = []
         for key in keys_to_compute:
-            dask_array = processed_predictions[key].rechunk("auto")
-            task = dask_array.to_zarr(
-                url=save_path,
-                component=key,
-                compute=False,
-            )
-            write_tasks.append(task)
+            dask_output = processed_predictions[key]
+            if isinstance(dask_output, da.Array):
+                dask_output = dask_output.rechunk("auto")
+                task = dask_output.to_zarr(
+                    url=save_path, component=key, compute=False, object_codec=None
+                )
+                write_tasks.append(task)
+
+            if isinstance(dask_output, list) and all(
+                isinstance(dask_array, da.Array) for dask_array in dask_output
+            ):
+                for i, dask_array in enumerate(dask_output):
+                    object_codec = (
+                        Pickle() if dask_array.dtype == "object" else None
+                    )
+                    task = dask_array.to_zarr(
+                        url=save_path,
+                        component=f"{key}/{i}",
+                        compute=False,
+                        object_codec=object_codec,
+                    )
+                    write_tasks.append(task)
+
         msg = f"Saving output to {save_path}."
         logger.info(msg=msg)
         with ProgressBar():
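The saving path now branches on the prediction type: a single dask array is written to one component, while a list of per-patch dask arrays is written to one component per patch, with numcodecs' Pickle codec attached only when a patch array has object dtype. In both branches the writes are deferred (compute=False) and executed together under a progress bar. Below is a minimal standalone sketch of that deferred-write pattern with ordinary numeric arrays; the store path and data are invented:

import dask.array as da
import numpy as np
from dask import compute
from dask.diagnostics import ProgressBar

# Two toy prediction components written to one zarr store in a single pass.
predictions = {
    "x": da.from_array(np.array([1, 3], dtype=np.uint32), chunks=2),
    "probs": da.from_array(np.array([1.0, 0.5], dtype=np.float32), chunks=2),
}

write_tasks = []
for key, dask_array in predictions.items():
    # compute=False returns a delayed write task instead of writing immediately.
    task = dask_array.to_zarr(url="toy_predictions.zarr", component=key, compute=False)
    write_tasks.append(task)

with ProgressBar():
    compute(*write_tasks)  # all components are written in one scheduler pass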
