1010into a full 3D volume.
1111
1212"""
13-
13+ from concurrent .futures import (
14+ ThreadPoolExecutor ,
15+ as_completed ,
16+ )
1417from tqdm import tqdm
1518
1619import numpy as np
1720import torch
1821
1922
2023def predict (
21- img , model , batch_size = 32 , patch_size = 64 , overlap = 16 , verbose = True
24+ img , model , batch_size = 32 , patch_size = 64 , overlap = 16 , trim = 5 , verbose = True
2225):
2326 """
2427 Denoises a 3D image by processing patches in batches and running deep
@@ -46,51 +49,76 @@ def predict(
4649 preds : List[numpy.ndarray]
4750 List of predicted patches (3D arrays) matching the patch size.
4851 """
49- # Initializations
50- batch_coords , batch_inputs , mn_mx = list (), list (), list ()
52+ # Adjust image dimenions
5153 while len (img .shape ) < 5 :
5254 img = img [np .newaxis , ...]
53- coords = generate_coords (img , patch_size , overlap )
55+
56+ # Initializations
57+ starts = generate_patch_starts (img , patch_size , overlap )
58+ denoised = np .zeros_like (img , dtype = np .uint16 )
5459
5560 # Main
56- pbar = tqdm (total = len (coords ), desc = "Denoise" ) if verbose else None
57- preds = list ()
58- for idx , (i , j , k ) in enumerate (coords ):
59- # Get end coord
60- i_end = min (i + patch_size , img .shape [2 ])
61- j_end = min (j + patch_size , img .shape [3 ])
62- k_end = min (k + patch_size , img .shape [4 ])
63-
64- # Get patch
65- patch = img [0 , 0 , i :i_end , j :j_end , k :k_end ]
61+ pbar = tqdm (total = len (starts ), desc = "Denoise" ) if verbose else None
62+ for i in range (0 , len (starts ), batch_size ):
63+ # Run model
64+ starts_i = starts [i :min (i + batch_size , len (starts ))]
65+ patches_i = _predict_batch (img , model , starts_i , patch_size , trim )
66+
67+ # Store result
68+ for patch , start in zip (patches_i , starts_i ):
69+ start = [max (s + trim , 0 ) for s in start ]
70+ end = [start [i ] + patch .shape [i ] for i in range (3 )]
71+ end = [min (e , s ) for e , s in zip (end , img .shape [2 :])]
72+ denoised [
73+ 0 , 0 , start [0 ]:end [0 ], start [1 ]:end [1 ], start [2 ]:end [2 ]
74+ ] = patch [: end [0 ] - start [0 ], : end [1 ] - start [1 ], : end [2 ] - start [2 ]]
75+ pbar .update (len (starts_i )) if verbose else None
76+ return denoised
77+
78+
79+ def _predict_batch (img , model , starts , patch_size , trim = 5 ):
80+ # Subroutine
81+ def read_patch (i ):
82+ start = starts [i ]
83+ end = [min (s + patch_size , d ) for s , d in zip (start , img .shape [2 :])]
84+ patch = img [0 , 0 , start [0 ]:end [0 ], start [1 ]:end [1 ], start [2 ]:end [2 ]]
6685 mn , mx = np .percentile (patch , 5 ), np .percentile (patch , 99.9 )
67- patch = (patch - mn ) / mx
68- mn_mx .append ((mn , mx ))
69-
70- # Store patch
71- patch = add_padding (patch , patch_size )
72- batch_inputs .append (patch )
73- batch_coords .append ((i , j , k ))
74-
75- # If batch is full or it's the last patch
76- if len (batch_inputs ) == batch_size or idx == len (coords ) - 1 :
77- # Run model
78- input_tensor = batch_to_tensor (np .stack (batch_inputs ))
79- with torch .no_grad ():
80- output_tensor = model (input_tensor )
81-
82- # Store result
83- output_tensor = output_tensor .cpu ()
84- for cnt in range (output_tensor .shape [0 ]):
85- mn , mx = mn_mx [cnt ]
86- patch = np .array (output_tensor [cnt , 0 , ...])
87- preds .append (np .maximum (patch * mx + mn , 0 ))
88- pbar .update (1 ) if verbose else None
89-
90- batch_coords .clear ()
91- batch_inputs .clear ()
92- mn_mx .clear ()
93- return stitch (img , coords , preds )
86+ patch = add_padding ((patch - mn ) / mx , patch_size )
87+ return i , patch , (mn , mx )
88+
89+ # Main
90+ with ThreadPoolExecutor () as executor :
91+ # Read patches
92+ threads = list ()
93+ for i in range (len (starts )):
94+ threads .append (executor .submit (read_patch , i ))
95+
96+ # Compile batch
97+ inputs = np .zeros ((len (starts ),) + (patch_size ,) * 3 )
98+ mn_mx = len (starts ) * [None ]
99+ for thread in as_completed (threads ):
100+ i , patch_i , mn_mx_i = thread .result ()
101+ mn_mx [i ] = mn_mx_i
102+ inputs [i , ...] = patch_i
103+
104+ # Run model
105+ inputs = batch_to_tensor (inputs )
106+ with torch .no_grad ():
107+ outputs = model (inputs )
108+ outputs = np .array (outputs .cpu ()).squeeze (1 )
109+
110+ # Store result
111+ preds = list ()
112+ start , end = trim , patch_size - trim
113+ for i in range (outputs .shape [0 ]):
114+ mn , mx = mn_mx [i ]
115+ pred = np .maximum (outputs [i ] * mx + mn , 0 ).astype (np .uint16 )
116+ preds .append (pred [start :end , start :end , start :end ])
117+ return preds
118+
119+
120+ def predict_largescale (img , model ):
121+ pass
94122
95123
96124def predict_patch (patch , model ):
@@ -117,67 +145,7 @@ def predict_patch(patch, model):
117145
118146 # Process output
119147 pred = np .array (output_tensor .cpu ())
120- return np .maximum (pred [0 , 0 , ...] * mx + mn , 0 ).astype (int )
121-
122-
123- def stitch (img , coords , preds , patch_size = 64 , trim = 5 ):
124- """
125- Stitches overlapping 3D patches back into a full denoised image by
126- averaging overlapping regions, with optional trimming of patch borders.
127-
128- Parameters
129- ----------
130- img : numpy.ndarray
131- Original image array of shape (batch, channels, depth, height, width).
132- coords : List[Tuple[int]]
133- List of starting (i, j, k) coordinates for each patch.
134- preds : List[numpy.ndarray]
135- Predicted patches with shape (patch_size, patch_size, patch_size).
136- patch_size : int, optional
137- Size of each cubic patch. Default is 64.
138- trim : int, optional
139- Number of voxels to trim from each side of a patch before stitching.
140- Default is 5.
141-
142- Returns
143- -------
144- numpy.ndarray
145- Reconstructed image with patches stitched and overlapping areas
146- averaged.
147- """
148- denoised_accum = np .zeros_like (img , dtype = np .float32 )
149- weight_map = np .zeros_like (img , dtype = np .float32 )
150- for (i , j , k ), pred in zip (coords , preds ):
151- # Trim prediction
152- start , end = trim , patch_size - trim
153- pred = pred [start :end , start :end , start :end ]
154-
155- # Adjust insertion indices
156- i_start = i + trim
157- j_start = j + trim
158- k_start = k + trim
159-
160- i_end = i_start + pred .shape [0 ]
161- j_end = j_start + pred .shape [1 ]
162- k_end = k_start + pred .shape [2 ]
163-
164- # Clip to image bounds (for safety)
165- i_end = min (i_end , img .shape [2 ])
166- j_end = min (j_end , img .shape [3 ])
167- k_end = min (k_end , img .shape [4 ])
168-
169- i_start = max (i_start , 0 )
170- j_start = max (j_start , 0 )
171- k_start = max (k_start , 0 )
172-
173- denoised_accum [
174- 0 , 0 , i_start :i_end , j_start :j_end , k_start :k_end
175- ] += pred [: i_end - i_start , : j_end - j_start , : k_end - k_start ]
176- weight_map [0 , 0 , i_start :i_end , j_start :j_end , k_start :k_end ] += 1
177-
178- # Average accumulated
179- weight_map [weight_map == 0 ] = 1
180- return denoised_accum / weight_map
148+ return np .maximum (pred [0 , 0 , ...] * mx + mn , 0 ).astype (np .uint16 )
181149
182150
183151# --- Helpers ---
@@ -205,7 +173,7 @@ def add_padding(patch, patch_size):
205173 return np .pad (patch , pad_width , mode = "constant" , constant_values = 0 )
206174
207175
208- def generate_coords (img , patch_size , overlap ):
176+ def generate_patch_starts (img , patch_size , overlap ):
209177 """
210178 Generates starting coordinates for 3D patches extracted from an image
211179 tensor, based on specified patch size and overlap.
0 commit comments