Skip to content

Commit fea9edf

Browse files
committed
fix: minor bug fixes
1 parent b67b568 commit fea9edf

File tree

7 files changed

+83
-116
lines changed

7 files changed

+83
-116
lines changed

cellseg_models_pytorch/inference/inferer.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from functools import partial
2-
from itertools import chain
32
from typing import Any, Dict, Generator, List, Tuple
43

54
import numpy as np
@@ -443,8 +442,8 @@ def _check_and_set_head_args(self) -> None:
443442

444443
def _get_out_info(self) -> Tuple[Tuple[str, int]]:
445444
"""Get the output names and number of out channels."""
446-
return tuple(
447-
chain.from_iterable(
448-
list(self.model.heads[k].items()) for k in self.model.heads.keys()
449-
)
450-
)
445+
return [
446+
(f"{decoder}-{head}", n)
447+
for decoder, inner in self.model.heads.items()
448+
for head, n in inner.items()
449+
]

cellseg_models_pytorch/models/cellpose/cellpose.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(
2929
enc_pretrain: bool = True,
3030
enc_freeze: bool = False,
3131
enc_out_indices: Tuple[int, ...] = None,
32-
upsampling: str = "fixed-unpool",
32+
upsampling: str = "bilinear",
3333
long_skip: str = "unet",
3434
merge_policy: str = "sum",
3535
short_skip: str = "basic",
@@ -131,6 +131,7 @@ def __init__(
131131

132132
self.enc_freeze = enc_freeze
133133
use_style = style_channels is not None
134+
self.decoders = decoders
134135
self.heads = heads
135136

136137
# Create build args

cellseg_models_pytorch/transforms/albu_transforms/norm_transforms.py

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ def __init__(
2727
) -> None:
2828
"""Min-max normalization transformation.
2929
30-
Parameters
31-
----------
32-
amin : float, optional
30+
Parameters:
31+
amin (float, default=None)
3332
Clamp min value. No clamping performed if None.
34-
amax : float, optional
33+
amax (float, default=None)
3534
Clamp max value. No clamping performed if None.
36-
p : float, default=1.0
35+
p (float, default=1.0):
3736
Probability of applying the transformation.
38-
copy : bool, default=False
37+
copy (bool, default=False):
3938
If True, normalize the copy of the input.
4039
"""
4140
if not HAS_ALBU:
@@ -52,13 +51,11 @@ def __init__(
5251
def apply(self, image: np.ndarray, **kwargs) -> np.ndarray:
5352
"""Apply min-max normalization.
5453
55-
Parameters
56-
----------
57-
image : np.ndarray:
54+
Parameters:
55+
image (np.ndarray):
5856
Input image to be normalized. Shape (H, W, C)|(H, W).
5957
60-
Returns
61-
-------
58+
Returns:
6259
np.ndarray:
6360
Normalized image. Same shape as input. dtype: float32.
6461
"""
@@ -80,15 +77,14 @@ def __init__(
8077
) -> None:
8178
"""Percentile normalization transformation.
8279
83-
Parameters
84-
----------
85-
amin : float, optional
80+
Parameters:
81+
amin (float, default=None):
8682
Clamp min value. No clamping performed if None.
87-
amax : float, optional
83+
amax (float, default=None):
8884
Clamp max value. No clamping performed if None.
89-
p : float, default=1.0
85+
p (float, default=1.0):
9086
Probability of applying the transformation.
91-
copy : bool, default=False
87+
copy (bool, default=False):
9288
If True, normalize the copy of the input.
9389
"""
9490
if not HAS_ALBU:
@@ -105,13 +101,11 @@ def __init__(
105101
def apply(self, image: np.ndarray, **kwargs) -> np.ndarray:
106102
"""Apply percentile normalization to input image.
107103
108-
Parameters
109-
----------
110-
image : np.ndarray:
104+
Parameters:
105+
image (np.ndarray):
111106
Input image to be normalized. Shape (H, W, C)|(H, W).
112107
113-
Returns
114-
-------
108+
Returns:
115109
np.ndarray:
116110
Normalized image. Same shape as input. dtype: float32.
117111
"""
@@ -136,19 +130,18 @@ def __init__(
136130
137131
NOTE: this is not dataset-level normalization but image-level.
138132
139-
Parameters
140-
----------
141-
standardize : bool, default=True
133+
Parameters:
134+
standardize (bool, default=True):
142135
If True, divides the mean shifted img by the standard deviation.
143-
amin : float, optional
136+
amin (float, default=None):
144137
Clamp min value. No clamping performed if None.
145-
amax : float, optional
138+
amax (float, default=None):
146139
Clamp max value. No clamping performed if None.
147-
always_apply : bool, default=True
140+
always_apply (bool, default=True):
148141
Apply the transformation always.
149-
p : float, default=1.0
142+
p (float, default=1.0):
150143
Probability of applying the transformation.
151-
copy : bool, default=False
144+
copy (bool, default=False):
152145
If True, normalize the copy of the input.
153146
"""
154147
if not HAS_ALBU:
@@ -165,13 +158,11 @@ def __init__(
165158
def apply(self, image: np.ndarray, **kwargs) -> np.ndarray:
166159
"""Apply image-level normalization to input image.
167160
168-
Parameters
169-
----------
170-
image : np.ndarray:
161+
Parameters:
162+
image (np.ndarray):
171163
Input image to be normalized. Shape (H, W, C)|(H, W).
172164
173-
Returns
174-
-------
165+
Returns:
175166
np.ndarray:
176167
Normalized image. Same shape as input. dtype: float32.
177168
"""

cellseg_models_pytorch/transforms/functional/normalization.py

Lines changed: 40 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,22 @@ def percentile_normalize(
1515
) -> np.ndarray:
1616
"""Channelwise percentile normalization to range [0, 1].
1717
18-
Parameters
19-
----------
20-
img : np.ndarray:
18+
Parameters:
19+
img (np.ndarray):
2120
Input image to be normalized. Shape (H, W, C)|(H, W).
22-
lower : float, default=0.01:
21+
lower (float, default=0.01):
2322
The lower percentile
24-
upper : float, default=99.99:
23+
upper (float, default=99.99):
2524
The upper percentile
26-
copy : bool, default=False
25+
copy (bool, default=False):
2726
If True, normalize the copy of the input.
2827
29-
Returns
30-
-------
28+
Returns:
3129
np.ndarray:
3230
Normalized img. Same shape as input. dtype: float32.
3331
34-
Raises
35-
------
36-
ValueError
32+
Raises:
33+
ValueError:
3734
If input image does not have shape (H, W) or (H, W, C).
3835
"""
3936
axis = (0, 1)
@@ -53,34 +50,30 @@ def percentile_normalize(
5350
upercentile = np.percentile(im, upper)
5451
lpercentile = np.percentile(im, lower)
5552

56-
# return np.interp(im, (lpercentile, upercentile), axis).astype(np.float32)
57-
return np.interp(im, (lpercentile, upercentile), axis) # .astype(np.float32)
53+
return np.interp(im, (lpercentile, upercentile), axis).astype(np.float32)
5854

5955

6056
def percentile_normalize99(
6157
img: np.ndarray, amin: float = None, amax: float = None, copy: bool = False
6258
) -> np.ndarray:
6359
"""Channelwise 1-99 percentile normalization. Optional clamping.
6460
65-
Parameters
66-
----------
67-
img : np.ndarray:
61+
Parameters:
62+
img (np.ndarray)
6863
Input image to be normalized. Shape (H, W, C)|(H, W).
69-
amin : float, optional
64+
amin (float, default=None)
7065
Clamp min value. No clamping performed if None.
71-
amax : float, optional
66+
amax (float, default=None):
7267
Clamp max value. No clamping performed if None.
73-
copy : bool, default=False
68+
copy (bool, default=False):
7469
If True, normalize the copy of the input.
7570
76-
Returns
77-
-------
71+
Returns:
7872
np.ndarray:
7973
Normalized image. Same shape as input. dtype: float32.
8074
81-
Raises
82-
------
83-
ValueError
75+
Raises:
76+
ValueError:
8477
If input image does not have shape (H, W) or (H, W, C).
8578
"""
8679
axis = (0, 1)
@@ -99,7 +92,7 @@ def percentile_normalize99(
9992
percentile99 = np.percentile(im, q=99, axis=axis)
10093
num = im - percentile1
10194
denom = percentile99 - percentile1
102-
im = num / denom if denom != 0 else np.zeros_like(im)
95+
im = num / denom
10396

10497
# clamp
10598
if not any(x is None for x in (amin, amax)):
@@ -117,27 +110,24 @@ def normalize(
117110
) -> np.ndarray:
118111
"""Channelwise mean centering or standardizing of an image. Optional clamping.
119112
120-
Parameters
121-
----------
122-
img : np.ndarray
113+
Parameters:
114+
img (np.ndarray):
123115
Input image to be normalized. Shape (H, W, C)|(H, W).
124-
standardize: bool, default=True
116+
standardize (bool, default=True):
125117
If True, divide with standard deviation after mean centering
126-
amin : float, optional
118+
amin (float, default=None):
127119
Clamp min value. No clamping performed if None.
128-
amax : float, optional
120+
amax (float, default=None):
129121
Clamp max value. No clamping performed if None.
130-
copy : bool, default=False
122+
copy (bool, default=False):
131123
If True, normalize the copy of the input.
132124
133-
Returns
134-
-------
125+
Returns:
135126
np.ndarray:
136127
Normalized image. Same shape as input. dtype: float32.
137128
138-
Raises
139-
------
140-
ValueError
129+
Raises:
130+
ValueError:
141131
If input image does not have shape (H, W) or (H, W, C).
142132
"""
143133
axis = (0, 1)
@@ -157,7 +147,7 @@ def normalize(
157147

158148
if standardize:
159149
std = im.std(axis=axis, keepdims=True)
160-
im = im / std if std != 0 else np.zeros_like(im)
150+
im = np.divide(im, std, where=std != 0)
161151

162152
# clamp
163153
if not any(x is None for x in (amin, amax)):
@@ -171,25 +161,22 @@ def minmax_normalize(
171161
) -> np.ndarray:
172162
"""Min-max normalization per image channel. Optional clamping.
173163
174-
Parameters
175-
----------
176-
img : np.ndarray:
164+
Parameters:
165+
img (np.ndarray):
177166
Input image to be normalized. Shape (H, W, C)|(H, W).
178-
amin : float, optional
167+
amin (float, default=None):
179168
Clamp min value. No clamping performed if None.
180-
amax : float, optional
169+
amax (float, default=None):
181170
Clamp max value. No clamping performed if None.
182-
copy : bool, default=False
171+
copy (bool, default=False):
183172
If True, normalize the copy of the input.
184173
185-
Returns
186-
-------
174+
Returns:
187175
np.ndarray:
188176
Min-max normalized image. Same shape as input. dtype: float32.
189177
190-
Raises
191-
------
192-
ValueError
178+
Raises:
179+
ValueError;
193180
If input image does not have shape (H, W) or (H, W, C).
194181
"""
195182
if img.ndim not in (2, 3):
@@ -206,7 +193,7 @@ def minmax_normalize(
206193
max = im.max()
207194
denom = max - min
208195
num = im - min
209-
im = num / denom if denom != 0 else np.zeros_like(im)
196+
im = num / denom
210197

211198
# clamp
212199
if not any(x is None for x in (amin, amax)):
@@ -221,16 +208,14 @@ def float2ubyte(mat: np.ndarray, normalize: bool = False) -> np.ndarray:
221208
Float matrix values need to be in range [-1, 1] for img_as_ubyte so
222209
the image is normalized or clamped before conversion.
223210
224-
Parameters
225-
----------
226-
mat : np.ndarray
211+
Parameters:
212+
mat (np.ndarray):
227213
A float64 matrix. Shape (H, W, C).
228214
normalize (bool, default=False):
229215
Normalizes input to [0, 1] first. If not True,
230216
clips values between [-1, 1].
231217
232-
Returns
233-
-------
218+
Returns:
234219
np.ndarray:
235220
A uint8 matrix. Shape (H, W, C). dtype: uint8.
236221
"""

cellseg_models_pytorch/utils/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
type_map_flatten,
2929
)
3030
from .multiproc import run_pool
31-
from .tensor_utils import tensor_one_hot, to_device, to_tensor
31+
from .tensor_utils import to_device, to_tensor
3232

3333
__all__ = [
3434
"Downloader",
@@ -54,8 +54,6 @@
5454
"label_semantic",
5555
"to_tensor",
5656
"to_device",
57-
"tensor_one_hot",
58-
"normalize_torch",
5957
"draw_stuff_contours",
6058
"draw_thing_contours",
6159
"majority_vote_sequential",

0 commit comments

Comments
 (0)