@@ -429,7 +429,7 @@ def infer_wsi(
429429 coordinates , labels = [], []
430430
431431 # Main output dictionary
432- raw_predictions = dict (zip (keys , [[] ] * len (keys )))
432+ raw_predictions = dict (zip (keys , [da . empty ( shape = ( 0 , 0 )) ] * len (keys )))
433433
434434 # Inference loop
435435 tqdm = get_tqdm ()
@@ -490,7 +490,7 @@ def infer_wsi(
490490 used_percent > memory_threshold
491491 or ((canvas .nbytes / vm .free ) * 100 ) > memory_threshold
492492 ):
493- tqdm_loop .desc = "Spilling to disk "
493+ tqdm_loop .desc = "Spill intermediate data to disk"
494494 if ((canvas .nbytes / vm .free ) * 100 ) > memory_threshold :
495495 used_percent = (canvas .nbytes / vm .free ) * 100
496496 msg = (
@@ -547,7 +547,7 @@ def infer_wsi(
547547 zarr_group ,
548548 save_path ,
549549 memory_threshold ,
550- ). rechunk ( "auto" )
550+ )
551551 raw_predictions ["coordinates" ] = da .concatenate (coordinates , axis = 0 )
552552 if self .return_labels :
553553 labels = [label .reshape (- 1 ) for label in labels ]
@@ -1008,10 +1008,11 @@ def merge_vertical_chunkwise(
10081008 probabilities = curr_chunk / curr_count .astype (np .float32 )
10091009
10101010 probabilities_zarr , probabilities_da = store_probabilities (
1011- probabilities ,
1012- probabilities_zarr ,
1013- probabilities_da ,
1014- zarr_group ,
1011+ probabilities = probabilities ,
1012+ chunk_shape = chunk_shape ,
1013+ probabilities_zarr = probabilities_zarr ,
1014+ probabilities_da = probabilities_da ,
1015+ zarr_group = zarr_group ,
10151016 )
10161017
10171018 if probabilities_da is not None :
@@ -1049,15 +1050,15 @@ def merge_vertical_chunkwise(
10491050 if "count" in zarr_group :
10501051 del zarr_group ["count" ]
10511052 return da .from_zarr (
1052- probabilities_zarr ,
1053- chunks = "auto" ,
1053+ probabilities_zarr , chunks = (chunk_shape [0 ], * probabilities .shape [1 :])
10541054 )
10551055
10561056 return probabilities_da
10571057
10581058
10591059def store_probabilities (
10601060 probabilities : np .ndarray ,
1061+ chunk_shape : tuple [int , ...],
10611062 probabilities_zarr : zarr .Array | None ,
10621063 probabilities_da : da .Array | None ,
10631064 zarr_group : zarr .Group | None ,
@@ -1071,6 +1072,8 @@ def store_probabilities(
10711072 Args:
10721073 probabilities (np.ndarray):
10731074 Computed probability array to store.
1075+ chunk_shape (tuple[int, ...]):
1076+ Chunk shape used for Zarr dataset creation.
10741077 probabilities_zarr (zarr.Array | None):
10751078 Existing Zarr dataset, or None to initialize.
10761079 probabilities_da (da.Array | None):
@@ -1088,6 +1091,7 @@ def store_probabilities(
10881091 probabilities_zarr = zarr_group .create_dataset (
10891092 name = "probabilities" ,
10901093 shape = (0 , * probabilities .shape [1 :]),
1094+ chunks = (chunk_shape [0 ], * probabilities .shape [1 :]),
10911095 dtype = probabilities .dtype ,
10921096 )
10931097
@@ -1101,7 +1105,9 @@ def store_probabilities(
11011105 else :
11021106 probabilities_da = concatenate_none (
11031107 old_arr = probabilities_da ,
1104- new_arr = da .from_array (probabilities , chunks = "auto" ),
1108+ new_arr = da .from_array (
1109+ probabilities , chunks = (chunk_shape [0 ], * probabilities .shape [1 :])
1110+ ),
11051111 )
11061112
11071113 return probabilities_zarr , probabilities_da
0 commit comments