1111import numpy as np
1212import torch
1313
14- from aind_exaspim_image_compression .utils import img_util , util
14+ from aind_exaspim_image_compression .utils import img_util
1515
1616
17- def predict (img , model , batch_size = 32 , patch_size = 64 , overlap = 16 ):
17+ def predict (
18+ img , model , batch_size = 32 , patch_size = 64 , overlap = 16 , verbose = True
19+ ):
1820 # Initializations
1921 batch_coords , batch_inputs , mn_mx = list (), list (), list ()
2022 coords = generate_coords (img , patch_size , overlap )
2123
2224 # Main
23- pbar = tqdm (total = len (coords ), desc = "Denoise" )
25+ pbar = tqdm (total = len (coords ), desc = "Denoise" ) if verbose else None
2426 preds = list ()
2527 for idx , (i , j , k ) in enumerate (coords ):
2628 # Get end coord
27- i_end = min (i + patch_size , img .shape [0 ])
28- j_end = min (j + patch_size , img .shape [1 ])
29- k_end = min (k + patch_size , img .shape [2 ])
29+ i_end = min (i + patch_size , img .shape [2 ])
30+ j_end = min (j + patch_size , img .shape [3 ])
31+ k_end = min (k + patch_size , img .shape [4 ])
3032
3133 # Get patch
32- patch = img [i :i_end , j :j_end , k :k_end ]
34+ patch = img [0 , 0 , i :i_end , j :j_end , k :k_end ]
3335 mn , mx = np .percentile (patch , 5 ), np .percentile (patch , 99.9 )
3436 patch = (patch - mn ) / mx
3537 mn_mx .append ((mn , mx ))
@@ -42,7 +44,7 @@ def predict(img, model, batch_size=32, patch_size=64, overlap=16):
4244 # If batch is full or it's the last patch
4345 if len (batch_inputs ) == batch_size or idx == len (coords ) - 1 :
4446 # Run model
45- input_tensor = to_tensor (np .stack (batch_inputs ))
47+ input_tensor = to_tensor (np .stack (batch_inputs ))
4648 with torch .no_grad ():
4749 output_tensor = model (input_tensor )
4850
@@ -51,8 +53,8 @@ def predict(img, model, batch_size=32, patch_size=64, overlap=16):
5153 for cnt in range (output_tensor .shape [0 ]):
5254 mn , mx = mn_mx [cnt ]
5355 patch = np .array (output_tensor [cnt , 0 , ...])
54- preds .append (patch * mx + mn )
55- pbar .update (1 )
56+ preds .append (np . maximum ( patch * mx + mn , 0 ) )
57+ pbar .update (1 ) if verbose else None
5658
5759 batch_coords .clear ()
5860 batch_inputs .clear ()
@@ -63,14 +65,15 @@ def predict(img, model, batch_size=32, patch_size=64, overlap=16):
6365def stitch (img , coords , preds , patch_size = 64 , trim = 5 ):
6466 denoised_accum = np .zeros_like (img , dtype = np .float32 )
6567 weight_map = np .zeros_like (img , dtype = np .float32 )
66-
6768 for (i , j , k ), pred in zip (coords , preds ):
6869 # Determine how much to trim
6970 trim_start = trim
7071 trim_end = patch_size - trim
7172
7273 # Trim prediction
73- pred_trimmed = pred [trim_start :trim_end , trim_start :trim_end , trim_start :trim_end ]
74+ pred_trimmed = pred [
75+ trim_start :trim_end , trim_start :trim_end , trim_start :trim_end
76+ ]
7477
7578 # Adjust insertion indices
7679 i_start = i + trim
@@ -82,23 +85,29 @@ def stitch(img, coords, preds, patch_size=64, trim=5):
8285 k_end = k_start + pred_trimmed .shape [2 ]
8386
8487 # Clip to image bounds (for safety)
85- i_end = min (i_end , img .shape [0 ])
86- j_end = min (j_end , img .shape [1 ])
87- k_end = min (k_end , img .shape [2 ])
88+ i_end = min (i_end , img .shape [2 ])
89+ j_end = min (j_end , img .shape [3 ])
90+ k_end = min (k_end , img .shape [4 ])
8891
8992 i_start = max (i_start , 0 )
9093 j_start = max (j_start , 0 )
9194 k_start = max (k_start , 0 )
9295
93- denoised_accum [i_start :i_end , j_start :j_end , k_start :k_end ] += pred_trimmed [:i_end - i_start , :j_end - j_start , :k_end - k_start ]
94- weight_map [i_start :i_end , j_start :j_end , k_start :k_end ] += 1
96+ denoised_accum [
97+ 0 , 0 , i_start :i_end , j_start :j_end , k_start :k_end
98+ ] += pred_trimmed [
99+ : i_end - i_start , : j_end - j_start , : k_end - k_start
100+ ]
101+ weight_map [0 , 0 , i_start :i_end , j_start :j_end , k_start :k_end ] += 1
95102
96103 # Average accumulated
97104 weight_map [weight_map == 0 ] = 1
98105 denoised = denoised_accum / weight_map
99106
100107 # Fill boundary trim
101- fill_value = np .percentile (denoised [trim :- trim , trim :- trim , trim :- trim ], 10 )
108+ fill_value = np .percentile (
109+ denoised [..., trim :- trim , trim :- trim , trim :- trim ], 10
110+ )
102111 return img_util .fill_boundary (denoised , trim , fill_value )
103112
104113
@@ -109,18 +118,19 @@ def add_padding(patch, patch_size):
109118 (0 , patch_size - patch .shape [1 ]),
110119 (0 , patch_size - patch .shape [2 ]),
111120 ]
112- return np .pad (patch , pad_width , mode = ' constant' , constant_values = 0 )
121+ return np .pad (patch , pad_width , mode = " constant" , constant_values = 0 )
113122
114123
115124def generate_coords (img , patch_size , overlap ):
116125 coords = list ()
117126 stride = patch_size - overlap
118- for i in range (0 , img .shape [0 ] - patch_size + stride , stride ):
119- for j in range (0 , img .shape [1 ] - patch_size + stride , stride ):
120- for k in range (0 , img .shape [2 ] - patch_size + stride , stride ):
127+ for i in range (0 , img .shape [2 ] - patch_size + stride , stride ):
128+ for j in range (0 , img .shape [3 ] - patch_size + stride , stride ):
129+ for k in range (0 , img .shape [4 ] - patch_size + stride , stride ):
121130 coords .append ((i , j , k ))
122131 return coords
123132
124133
125134def to_tensor (arr ):
126- return torch .tensor (arr [:, np .newaxis , ...]).to ("cuda" )
135+ dtype = torch .float
136+ return torch .tensor (arr [:, np .newaxis , ...]).to ("cuda" , dtype = dtype )
0 commit comments