88import dask .array as da
99import numpy as np
1010import torch
11- from dask .diagnostics import ProgressBar
1211from typing_extensions import Unpack
1312
1413from tiatoolbox import logger
@@ -394,9 +393,6 @@ def infer_wsi(
394393 return_coordinates = True ,
395394 )
396395
397- with ProgressBar ():
398- da_probabilities = raw_predictions ["probabilities" ].persist ()
399-
400396 progress_bar = None
401397 tqdm = get_tqdm ()
402398
@@ -411,13 +407,13 @@ def infer_wsi(
411407 merged_shape = (
412408 max_location [3 ],
413409 max_location [2 ],
414- da_probabilities .shape [3 ],
410+ raw_predictions [ "probabilities" ] .shape [3 ],
415411 )
416412
417413 # creating dask arrays for faster processing
418414 merged_probabilities = da .zeros (
419415 shape = merged_shape ,
420- dtype = da_probabilities .dtype ,
416+ dtype = raw_predictions [ "probabilities" ] .dtype ,
421417 chunks = merged_shape ,
422418 )
423419
@@ -429,7 +425,7 @@ def infer_wsi(
429425
430426 for idx , location in enumerate (self .output_locations ):
431427 start_x , start_y , end_x , end_y = location
432- patch_probs = da_probabilities [
428+ patch_probs = raw_predictions [ "probabilities" ] [
433429 idx , 0 : end_y - start_y , 0 : end_x - start_x , :
434430 ]
435431 merged_probabilities [start_y :end_y , start_x :end_x , :] = (
0 commit comments