Skip to content

Commit 453a9c0

Browse files
authored
Torch GaussianSmooth, RandGaussianSmooth, GaussianSharpen, RandGaussianSharpen (#2971)
Torch `GaussianSmooth`, `RandGaussianSmooth`, `GaussianSharpen`, `RandGaussianSharpen`
1 parent a9cd2d8 commit 453a9c0

11 files changed

+617
-320
lines changed

monai/networks/layers/simplelayers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# limitations under the License.
1111

1212
import math
13+
from copy import deepcopy
1314
from typing import List, Sequence, Union
1415

1516
import torch
@@ -24,10 +25,10 @@
2425
ChannelMatching,
2526
InvalidPyTorchVersionError,
2627
SkipMode,
27-
ensure_tuple_rep,
2828
look_up_option,
2929
optional_import,
3030
)
31+
from monai.utils.misc import issequenceiterable
3132

3233
_C, _ = optional_import("monai._C")
3334
if not PT_BEFORE_1_7:
@@ -393,13 +394,18 @@ def __init__(
393394
(for example `parameters()` iterator could be used to get the parameters);
394395
otherwise this module will fix the kernels using `sigma` as the std.
395396
"""
397+
if issequenceiterable(sigma):
398+
if len(sigma) != spatial_dims: # type: ignore
399+
raise ValueError
400+
else:
401+
sigma = [deepcopy(sigma) for _ in range(spatial_dims)] # type: ignore
396402
super().__init__()
397403
self.sigma = [
398404
torch.nn.Parameter(
399405
torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None),
400406
requires_grad=requires_grad,
401407
)
402-
for s in ensure_tuple_rep(sigma, int(spatial_dims))
408+
for s in sigma # type: ignore
403409
]
404410
self.truncated = truncated
405411
self.approx = approx

monai/transforms/intensity/array.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,15 +1030,24 @@ class GaussianSmooth(Transform):
10301030
10311031
"""
10321032

1033+
backend = [TransformBackends.TORCH]
1034+
10331035
def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "erf") -> None:
10341036
self.sigma = sigma
10351037
self.approx = approx
10361038

1037-
def __call__(self, img: np.ndarray):
1038-
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
1039-
gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx)
1040-
input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0)
1041-
return gaussian_filter(input_data).squeeze(0).detach().numpy()
1039+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
1040+
img_t: torch.Tensor
1041+
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) # type: ignore
1042+
sigma: Union[Sequence[torch.Tensor], torch.Tensor]
1043+
if isinstance(self.sigma, Sequence):
1044+
sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma]
1045+
else:
1046+
sigma = torch.as_tensor(self.sigma, device=img_t.device)
1047+
gaussian_filter = GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx)
1048+
out_t: torch.Tensor = gaussian_filter(img_t.unsqueeze(0)).squeeze(0)
1049+
out, *_ = convert_data_type(out_t, type(img), device=img.device if isinstance(img, torch.Tensor) else None)
1050+
return out
10421051

10431052

10441053
class RandGaussianSmooth(RandomizableTransform):
@@ -1079,10 +1088,10 @@ def randomize(self, data: Optional[Any] = None) -> None:
10791088
self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1])
10801089
self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1])
10811090

1082-
def __call__(self, img: np.ndarray):
1083-
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
1091+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
10841092
self.randomize()
10851093
if not self._do_transform:
1094+
img, *_ = convert_data_type(img, dtype=torch.float)
10861095
return img
10871096
sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=img.ndim - 1)
10881097
return GaussianSmooth(sigma=sigma, approx=self.approx)(img)
@@ -1115,6 +1124,8 @@ class GaussianSharpen(Transform):
11151124
11161125
"""
11171126

1127+
backend = [TransformBackends.TORCH]
1128+
11181129
def __init__(
11191130
self,
11201131
sigma1: Union[Sequence[float], float] = 3.0,
@@ -1127,14 +1138,19 @@ def __init__(
11271138
self.alpha = alpha
11281139
self.approx = approx
11291140

1130-
def __call__(self, img: np.ndarray):
1131-
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
1132-
gaussian_filter1 = GaussianFilter(img.ndim - 1, self.sigma1, approx=self.approx)
1133-
gaussian_filter2 = GaussianFilter(img.ndim - 1, self.sigma2, approx=self.approx)
1134-
input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0)
1135-
blurred_f = gaussian_filter1(input_data)
1136-
filter_blurred_f = gaussian_filter2(blurred_f)
1137-
return (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0).detach().numpy()
1141+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
1142+
img_t: torch.Tensor
1143+
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore
1144+
1145+
gf1, gf2 = [
1146+
GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device)
1147+
for sigma in (self.sigma1, self.sigma2)
1148+
]
1149+
blurred_f = gf1(img_t.unsqueeze(0))
1150+
filter_blurred_f = gf2(blurred_f)
1151+
out_t: torch.Tensor = (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0)
1152+
out, *_ = convert_data_type(out_t, type(img), device=img.device if isinstance(img, torch.Tensor) else None)
1153+
return out
11381154

11391155

11401156
class RandGaussianSharpen(RandomizableTransform):
@@ -1159,6 +1175,8 @@ class RandGaussianSharpen(RandomizableTransform):
11591175
11601176
"""
11611177

1178+
backend = GaussianSharpen.backend
1179+
11621180
def __init__(
11631181
self,
11641182
sigma1_x: Tuple[float, float] = (0.5, 1.0),
@@ -1194,10 +1212,11 @@ def randomize(self, data: Optional[Any] = None) -> None:
11941212
self.z2 = self.R.uniform(low=sigma2_z[0], high=sigma2_z[1])
11951213
self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1])
11961214

1197-
def __call__(self, img: np.ndarray):
1198-
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
1215+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
11991216
self.randomize()
1217+
# if not doing, just need to convert to tensor
12001218
if not self._do_transform:
1219+
img, *_ = convert_data_type(img, dtype=torch.float32)
12011220
return img
12021221
sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=img.ndim - 1)
12031222
sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=img.ndim - 1)

monai/transforms/intensity/dictionary.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
4949
from monai.transforms.utils import is_positive
5050
from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size
51+
from monai.utils.type_conversion import convert_data_type
5152

5253
__all__ = [
5354
"RandGaussianNoised",
@@ -897,6 +898,8 @@ class GaussianSmoothd(MapTransform):
897898
898899
"""
899900

901+
backend = GaussianSmooth.backend
902+
900903
def __init__(
901904
self,
902905
keys: KeysCollection,
@@ -907,7 +910,7 @@ def __init__(
907910
super().__init__(keys, allow_missing_keys)
908911
self.converter = GaussianSmooth(sigma, approx=approx)
909912

910-
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
913+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
911914
d = dict(data)
912915
for key in self.key_iterator(d):
913916
d[key] = self.converter(d[key])
@@ -931,6 +934,8 @@ class RandGaussianSmoothd(RandomizableTransform, MapTransform):
931934
932935
"""
933936

937+
backend = GaussianSmooth.backend
938+
934939
def __init__(
935940
self,
936941
keys: KeysCollection,
@@ -954,14 +959,15 @@ def randomize(self, data: Optional[Any] = None) -> None:
954959
self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1])
955960
self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1])
956961

957-
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
962+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
958963
d = dict(data)
959964
self.randomize()
960-
if not self._do_transform:
961-
return d
962965
for key in self.key_iterator(d):
963-
sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=d[key].ndim - 1)
964-
d[key] = GaussianSmooth(sigma=sigma, approx=self.approx)(d[key])
966+
if self._do_transform:
967+
sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=d[key].ndim - 1)
968+
d[key] = GaussianSmooth(sigma=sigma, approx=self.approx)(d[key])
969+
else:
970+
d[key], *_ = convert_data_type(d[key], torch.Tensor, dtype=torch.float)
965971
return d
966972

967973

@@ -985,6 +991,8 @@ class GaussianSharpend(MapTransform):
985991
986992
"""
987993

994+
backend = GaussianSharpen.backend
995+
988996
def __init__(
989997
self,
990998
keys: KeysCollection,
@@ -997,7 +1005,7 @@ def __init__(
9971005
super().__init__(keys, allow_missing_keys)
9981006
self.converter = GaussianSharpen(sigma1, sigma2, alpha, approx=approx)
9991007

1000-
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
1008+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
10011009
d = dict(data)
10021010
for key in self.key_iterator(d):
10031011
d[key] = self.converter(d[key])
@@ -1028,6 +1036,8 @@ class RandGaussianSharpend(RandomizableTransform, MapTransform):
10281036
10291037
"""
10301038

1039+
backend = GaussianSharpen.backend
1040+
10311041
def __init__(
10321042
self,
10331043
keys: KeysCollection,
@@ -1066,15 +1076,17 @@ def randomize(self, data: Optional[Any] = None) -> None:
10661076
self.z2 = self.R.uniform(low=sigma2_z[0], high=sigma2_z[1])
10671077
self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1])
10681078

1069-
def __call__(self, data):
1079+
def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
10701080
d = dict(data)
10711081
self.randomize()
1072-
if not self._do_transform:
1073-
return d
10741082
for key in self.key_iterator(d):
1075-
sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=d[key].ndim - 1)
1076-
sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=d[key].ndim - 1)
1077-
d[key] = GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(d[key])
1083+
if self._do_transform:
1084+
sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=d[key].ndim - 1)
1085+
sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=d[key].ndim - 1)
1086+
d[key] = GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(d[key])
1087+
else:
1088+
# if not doing the transform, convert to torch
1089+
d[key], *_ = convert_data_type(d[key], torch.Tensor, dtype=torch.float32)
10781090
return d
10791091

10801092

tests/test_gaussian_sharpen.py

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,50 +11,79 @@
1111

1212
import unittest
1313

14-
import numpy as np
1514
from parameterized import parameterized
1615

1716
from monai.transforms import GaussianSharpen
17+
from tests.utils import TEST_NDARRAYS, assert_allclose
1818

19-
TEST_CASE_1 = [
20-
{},
21-
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
22-
np.array(
19+
TESTS = []
20+
21+
for p in TEST_NDARRAYS:
22+
TESTS.append(
2323
[
24-
[[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]],
25-
[[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]],
24+
{},
25+
p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
26+
p(
27+
[
28+
[
29+
[4.1081963, 3.4950666, 4.1081963],
30+
[3.7239995, 2.8491793, 3.7239995],
31+
[4.569839, 3.9529324, 4.569839],
32+
],
33+
[[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]],
34+
]
35+
),
2636
]
27-
),
28-
]
37+
)
2938

30-
TEST_CASE_2 = [
31-
{"sigma1": 1.0, "sigma2": 0.75, "alpha": 20},
32-
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
33-
np.array(
39+
TESTS.append(
3440
[
35-
[[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]],
36-
[[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]],
41+
{"sigma1": 1.0, "sigma2": 0.75, "alpha": 20},
42+
p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
43+
p(
44+
[
45+
[
46+
[4.513644, 4.869134, 4.513644],
47+
[8.467242, 9.4004135, 8.467242],
48+
[10.416813, 12.0653515, 10.416813],
49+
],
50+
[
51+
[15.711488, 17.569994, 15.711488],
52+
[21.16811, 23.501041, 21.16811],
53+
[21.614658, 24.766209, 21.614658],
54+
],
55+
]
56+
),
3757
]
38-
),
39-
]
58+
)
4059

41-
TEST_CASE_3 = [
42-
{"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20},
43-
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
44-
np.array(
60+
TESTS.append(
4561
[
46-
[[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]],
47-
[[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]],
62+
{"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20},
63+
p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
64+
p(
65+
[
66+
[
67+
[3.3324685, 3.335536, 3.3324673],
68+
[7.7666636, 8.16056, 7.7666636],
69+
[12.662973, 14.317837, 12.6629715],
70+
],
71+
[
72+
[15.329051, 16.57557, 15.329051],
73+
[19.41665, 20.40139, 19.416655],
74+
[24.659554, 27.557873, 24.659554],
75+
],
76+
]
77+
),
4878
]
49-
),
50-
]
79+
)
5180

5281

5382
class TestGaussianSharpen(unittest.TestCase):
54-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
83+
@parameterized.expand(TESTS)
5584
def test_value(self, argments, image, expected_data):
5685
result = GaussianSharpen(**argments)(image)
57-
np.testing.assert_allclose(result, expected_data, rtol=1e-4)
86+
assert_allclose(result, expected_data, atol=0, rtol=1e-4, type_test=False)
5887

5988

6089
if __name__ == "__main__":

0 commit comments

Comments
 (0)