|
9 | 9 | import dask |
10 | 10 | import dask.array as da |
11 | 11 | import numpy as np |
12 | | -import psutil |
13 | 12 | import torch |
14 | 13 | import zarr |
15 | 14 | from dask import compute |
|
34 | 33 | from tiatoolbox.wsicore import WSIReader |
35 | 34 |
|
36 | 35 |
|
37 | | -def smart_divide( |
38 | | - merged_probabilities: zarr.Array, |
39 | | - merged_weights: zarr.Array, |
40 | | - tile_size: int = 2048, |
41 | | - safety_margin: float = 0.5, |
42 | | - *, |
43 | | - verbose: bool = False, |
44 | | -) -> zarr.Array: |
45 | | - """Use chunked division for Zarr if memory is low. |
46 | | -
|
47 | | - Divide merged_probabilities by merged_weights using full-array or chunked strategy |
48 | | - based on available system memory. |
49 | | -
|
50 | | - """ |
51 | | - h, w, c = merged_probabilities.shape |
52 | | - total_elements = h * w * c |
53 | | - estimated_memory = total_elements * 4 * 2 # float32 = 4 bytes, two arrays |
54 | | - |
55 | | - available_memory = psutil.virtual_memory().available |
56 | | - if estimated_memory < available_memory * safety_margin: |
57 | | - # Use full-array division |
58 | | - merged_weights[merged_weights == 0] = 1 |
59 | | - merged_probabilities[:] = merged_probabilities[:] / merged_weights[:] |
60 | | - else: # pragma: no cover |
61 | | - progress_bar = None |
62 | | - tqdm = get_tqdm() |
63 | | - |
64 | | - if verbose: |
65 | | - progress_bar = tqdm( |
66 | | - total=len(range(0, h, tile_size)), |
67 | | - leave=False, |
68 | | - desc="Merging Patches", |
69 | | - ) |
70 | | - # Use chunked division |
71 | | - for i in range(0, h, tile_size): |
72 | | - for j in range(0, w, tile_size): |
73 | | - i_end = min(i + tile_size, h) |
74 | | - j_end = min(j + tile_size, w) |
75 | | - prob_tile = merged_probabilities[i:i_end, j:j_end, :] |
76 | | - weight_tile = merged_weights[i:i_end, j:j_end, :] |
77 | | - weight_tile[weight_tile == 0] = 1 |
78 | | - merged_probabilities[i:i_end, j:j_end, :] = prob_tile / weight_tile |
79 | | - |
80 | | - if progress_bar: |
81 | | - progress_bar.update() |
82 | | - |
83 | | - if progress_bar: |
84 | | - progress_bar.close() |
85 | | - |
86 | | - return merged_probabilities |
87 | | - |
88 | | - |
89 | 36 | class SemanticSegmentorRunParams(PredictorRunParams): |
90 | 37 | """Class describing the input parameters for the :func:`EngineABC.run()` method. |
91 | 38 |
|
|
0 commit comments