
Commit a51bab2

⚡ Use optimized chunk size
1 parent eeb99df commit a51bab2

File tree

1 file changed (+16, -10 lines)

tiatoolbox/models/engine/semantic_segmentor.py

@@ -429,7 +429,7 @@ def infer_wsi(
         coordinates, labels = [], []

         # Main output dictionary
-        raw_predictions = dict(zip(keys, [[]] * len(keys)))
+        raw_predictions = dict(zip(keys, [da.empty(shape=(0, 0))] * len(keys)))

         # Inference loop
         tqdm = get_tqdm()
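
A note on this first hunk: seeding raw_predictions with zero-length Dask arrays rather than Python lists makes every value a dask.array from the start, so downstream code can overwrite or concatenate them without type checks. A minimal sketch of the pattern (the keys and batch shapes here are made up):

import dask.array as da

keys = ["probabilities", "coordinates"]  # made-up keys for illustration
raw_predictions = dict(zip(keys, [da.empty(shape=(0, 0))] * len(keys)))

# Real results later replace the placeholders, as in the diff below.
coordinates = [da.ones((4, 2)), da.ones((3, 2))]  # stand-in batch outputs
raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0)
print(raw_predictions["coordinates"].shape)  # (7, 2)
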
@@ -490,7 +490,7 @@ def infer_wsi(
                 used_percent > memory_threshold
                 or ((canvas.nbytes / vm.free) * 100) > memory_threshold
             ):
-                tqdm_loop.desc = "Spilling to disk "
+                tqdm_loop.desc = "Spill intermediate data to disk"
                 if ((canvas.nbytes / vm.free) * 100) > memory_threshold:
                     used_percent = (canvas.nbytes / vm.free) * 100
                     msg = (
@@ -547,7 +547,7 @@ def infer_wsi(
                 zarr_group,
                 save_path,
                 memory_threshold,
-            ).rechunk("auto")
+            )
             raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0)
             if self.return_labels:
                 labels = [label.reshape(-1) for label in labels]
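
Why the trailing .rechunk("auto") can go: merge_vertical_chunkwise (changed below) now returns the array already chunked deliberately when it comes back from Zarr, and rechunking an already well-chunked array just adds graph layers. A toy illustration, with invented sizes:

import numpy as np
import dask.array as da

arr = da.from_array(np.zeros((128, 16), dtype=np.float32), chunks=(64, 16))
print(arr.chunks)  # ((64, 64), (16,))

# "auto" re-derives chunk sizes from dask's configured target, which for a
# small array typically merges everything into one chunk -- extra graph
# work when the chunks were already chosen on purpose.
print(arr.rechunk("auto").chunks)  # e.g. ((128,), (16,))
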
@@ -1008,10 +1008,11 @@ def merge_vertical_chunkwise(
         probabilities = curr_chunk / curr_count.astype(np.float32)

         probabilities_zarr, probabilities_da = store_probabilities(
-            probabilities,
-            probabilities_zarr,
-            probabilities_da,
-            zarr_group,
+            probabilities=probabilities,
+            chunk_shape=chunk_shape,
+            probabilities_zarr=probabilities_zarr,
+            probabilities_da=probabilities_da,
+            zarr_group=zarr_group,
         )

         if probabilities_da is not None:
@@ -1049,15 +1050,15 @@ def merge_vertical_chunkwise(
         if "count" in zarr_group:
             del zarr_group["count"]
         return da.from_zarr(
-            probabilities_zarr,
-            chunks="auto",
+            probabilities_zarr, chunks=(chunk_shape[0], *probabilities.shape[1:])
         )

     return probabilities_da


 def store_probabilities(
     probabilities: np.ndarray,
+    chunk_shape: tuple[int, ...],
     probabilities_zarr: zarr.Array | None,
     probabilities_da: da.Array | None,
     zarr_group: zarr.Group | None,
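
The thread running through the remaining hunks: store_probabilities now receives the engine's chunk_shape and uses it both when creating the Zarr dataset and when wrapping it as a Dask array, so the Dask chunks line up one-to-one with the chunks on disk. A rough sketch of that alignment, assuming the zarr v2-style API used elsewhere in the file (array sizes are invented):

import numpy as np
import zarr
import dask.array as da

probabilities = np.random.rand(128, 256, 4).astype(np.float32)
chunk_shape = (64, 256, 4)  # invented: rows handled per vertical chunk

group = zarr.open_group("probs_demo.zarr", mode="w")
probabilities_zarr = group.create_dataset(
    name="probabilities",
    shape=(0, *probabilities.shape[1:]),
    chunks=(chunk_shape[0], *probabilities.shape[1:]),
    dtype=probabilities.dtype,
)
probabilities_zarr.append(probabilities, axis=0)

# Reading back with matching chunks means each Dask task loads whole Zarr
# chunks; chunks="auto" could straddle chunk boundaries and force partial
# reads plus a later rechunk.
arr = da.from_zarr(
    probabilities_zarr, chunks=(chunk_shape[0], *probabilities.shape[1:])
)
print(arr.chunks)  # ((64, 64), (256,), (4,))
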
@@ -1071,6 +1072,8 @@ def store_probabilities(
     Args:
         probabilities (np.ndarray):
             Computed probability array to store.
+        chunk_shape (tuple[int, ...]):
+            Chunk shape used for Zarr dataset creation.
         probabilities_zarr (zarr.Array | None):
             Existing Zarr dataset, or None to initialize.
         probabilities_da (da.Array | None):
@@ -1088,6 +1091,7 @@ def store_probabilities(
         probabilities_zarr = zarr_group.create_dataset(
             name="probabilities",
             shape=(0, *probabilities.shape[1:]),
+            chunks=(chunk_shape[0], *probabilities.shape[1:]),
             dtype=probabilities.dtype,
         )

@@ -1101,7 +1105,9 @@ def store_probabilities(
     else:
         probabilities_da = concatenate_none(
             old_arr=probabilities_da,
-            new_arr=da.from_array(probabilities, chunks="auto"),
+            new_arr=da.from_array(
+                probabilities, chunks=(chunk_shape[0], *probabilities.shape[1:])
+            ),
         )

     return probabilities_zarr, probabilities_da
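
Same idea on the in-memory path: giving da.from_array an explicit chunk shape keeps every appended block chunked identically, so the concatenated array has regular chunks instead of whatever "auto" picks per block. A small sketch (block sizes invented):

import numpy as np
import dask.array as da

chunk_rows = 64  # invented stand-in for chunk_shape[0]
blocks = [np.ones((64, 256, 4), dtype=np.float32) for _ in range(3)]

merged = da.concatenate(
    [da.from_array(b, chunks=(chunk_rows, *b.shape[1:])) for b in blocks],
    axis=0,
)
print(merged.chunks[0])  # (64, 64, 64): one uniform chunk per block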
