@@ -41,11 +41,10 @@ def predict(
4141
4242 Returns
4343 -------
44- coords : list of tuples
44+ coords : List[Tuple[int]]
4545 List of (i, j, k) starting coordinates of patches processed.
46- preds : list of numpy.ndarray
46+ preds : List[ numpy.ndarray]
4747 List of predicted patches (3D arrays) matching the patch size.
48-
4948 """
5049 # Initializations
5150 batch_coords , batch_inputs , mn_mx = list (), list (), list ()
@@ -74,7 +73,7 @@ def predict(
7473 # If batch is full or it's the last patch
7574 if len (batch_inputs ) == batch_size or idx == len (coords ) - 1 :
7675 # Run model
77- input_tensor = to_tensor (np .stack (batch_inputs ))
76+ input_tensor = batch_to_tensor (np .stack (batch_inputs ))
7877 with torch .no_grad ():
7978 output_tensor = model (input_tensor )
8079
@@ -92,6 +91,33 @@ def predict(
9291 return coords , preds
9392
9493
def predict_patch(patch, model):
    """
    Denoise a single 3D patch using the provided model.

    Parameters
    ----------
    patch : numpy.ndarray
        3D input patch to denoise.
    model : torch.nn.Module
        PyTorch model used for prediction.

    Returns
    -------
    numpy.ndarray
        Denoised 3D patch (integer dtype) with the same shape as the input
        patch.
    """
    # Robust normalization range from percentiles; the same scale must be
    # used to normalize the input and to undo the normalization on the
    # output (previously the denormalization used `mx` directly, which
    # disagreed with `max(mx, 1)` whenever mx < 1).
    mn, mx = np.percentile(patch, 5), np.percentile(patch, 99.9)
    scale = max(mx, 1)

    # Run model on the normalized patch (to_tensor adds batch/channel dims).
    input_tensor = to_tensor((patch - mn) / scale)
    with torch.no_grad():
        output_tensor = model(input_tensor)

    # Drop batch/channel dims, undo normalization, clamp negatives to zero.
    pred = np.array(output_tensor.cpu())
    return np.maximum(pred[0, 0, ...] * scale + mn, 0).astype(int)
119+
120+
95121def stitch (img , coords , preds , patch_size = 64 , trim = 5 ):
96122 """
97123 Stitches overlapping 3D patches back into a full denoised image by
@@ -116,28 +142,22 @@ def stitch(img, coords, preds, patch_size=64, trim=5):
116142 numpy.ndarray
117143 Reconstructed image with patches stitched and overlapping areas
118144 averaged.
119-
120145 """
121146 denoised_accum = np .zeros_like (img , dtype = np .float32 )
122147 weight_map = np .zeros_like (img , dtype = np .float32 )
123148 for (i , j , k ), pred in zip (coords , preds ):
124- # Determine how much to trim
125- trim_start = trim
126- trim_end = patch_size - trim
127-
128149 # Trim prediction
129- pred_trimmed = pred [
130- trim_start :trim_end , trim_start :trim_end , trim_start :trim_end
131- ]
150+ start , end = trim , patch_size - trim
151+ pred = pred [start :end , start :end , start :end ]
132152
133153 # Adjust insertion indices
134154 i_start = i + trim
135155 j_start = j + trim
136156 k_start = k + trim
137157
138- i_end = i_start + pred_trimmed .shape [0 ]
139- j_end = j_start + pred_trimmed .shape [1 ]
140- k_end = k_start + pred_trimmed .shape [2 ]
158+ i_end = i_start + pred .shape [0 ]
159+ j_end = j_start + pred .shape [1 ]
160+ k_end = k_start + pred .shape [2 ]
141161
142162 # Clip to image bounds (for safety)
143163 i_end = min (i_end , img .shape [2 ])
@@ -150,9 +170,7 @@ def stitch(img, coords, preds, patch_size=64, trim=5):
150170
151171 denoised_accum [
152172 0 , 0 , i_start :i_end , j_start :j_end , k_start :k_end
153- ] += pred_trimmed [
154- : i_end - i_start , : j_end - j_start , : k_end - k_start
155- ]
173+ ] += pred [: i_end - i_start , : j_end - j_start , : k_end - k_start ]
156174 weight_map [0 , 0 , i_start :i_end , j_start :j_end , k_start :k_end ] += 1
157175
158176 # Average accumulated
@@ -176,7 +194,6 @@ def add_padding(patch, patch_size):
176194 -------
177195 numpy.ndarray
178196 Zero-padded patch with shape (patch_size, patch_size, patch_size).
179-
180197 """
181198 pad_width = [
182199 (0 , patch_size - patch .shape [0 ]),
@@ -205,7 +222,6 @@ def generate_coords(img, patch_size, overlap):
205222 coords : List[Tuple[int]]
206223 List of (depth_start, height_start, width_start) coordinates for image
207224 patches.
208-
209225 """
210226 coords = list ()
211227 stride = patch_size - overlap
@@ -218,20 +234,37 @@ def generate_coords(img, patch_size, overlap):
218234
def to_tensor(arr, device="cuda"):
    """
    Converts a NumPy array to a PyTorch float tensor and moves it to the
    given device.

    Parameters
    ----------
    arr : numpy.ndarray
        Array to be converted.
    device : str or torch.device, optional
        Device the tensor is moved to. The default is "cuda", preserving
        the original behavior.

    Returns
    -------
    torch.Tensor
        Float tensor on ``device``. Leading singleton axes are prepended
        until the tensor is 5-dimensional, so a 3D volume becomes shape
        (1, 1, depth, height, width).
    """
    # Prepend singleton (batch/channel) axes until rank 5, the layout the
    # 3D model expects.
    while arr.ndim < 5:
        arr = arr[np.newaxis, ...]
    return torch.tensor(arr).to(device, dtype=torch.float)
253+
234254
def batch_to_tensor(arr):
    """
    Converts a NumPy array containing a batch of inputs to a PyTorch tensor
    and moves it to the GPU.

    Parameters
    ----------
    arr : numpy.ndarray
        Array to be converted, with shape (batch_size, depth, height, width).

    Returns
    -------
    torch.Tensor
        Tensor on GPU, with shape (batch_size, 1, depth, height, width).
    """
    # Insert the channel axis after the batch axis, then delegate device
    # placement and dtype conversion to to_tensor.
    with_channel = np.expand_dims(arr, axis=1)
    return to_tensor(with_channel)
0 commit comments