Skip to content

Commit 9ddc9e6

Browse files
authored
3313 Add affine arg to the dict transform (#3326)
* [DLMED] add affine to dict transform Signed-off-by: Nic Ma <[email protected]> * [DLMED] add unit tests Signed-off-by: Nic Ma <[email protected]>
1 parent d35759a commit 9ddc9e6

File tree

4 files changed

+59
-0
lines changed

4 files changed

+59
-0
lines changed

monai/transforms/spatial/dictionary.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,7 @@ def __init__(
605605
shear_params: Optional[Union[Sequence[float], float]] = None,
606606
translate_params: Optional[Union[Sequence[float], float]] = None,
607607
scale_params: Optional[Union[Sequence[float], float]] = None,
608+
affine: Optional[NdarrayOrTensor] = None,
608609
spatial_size: Optional[Union[Sequence[int], int]] = None,
609610
mode: GridSampleModeSequence = GridSampleMode.BILINEAR,
610611
padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION,
@@ -631,6 +632,9 @@ def __init__(
631632
pixel/voxel relative to the center of the input image. Defaults to no translation.
632633
scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D,
633634
a tuple of 3 floats for 3D. Defaults to `1.0`.
635+
affine: if applied, ignore the params (`rotate_params`, etc.) and use the
636+
supplied matrix. Should be square with each side = num of image spatial
637+
dimensions + 1.
634638
spatial_size: output image spatial size.
635639
if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
636640
the transform will use the spatial size of `img`.
@@ -662,6 +666,7 @@ def __init__(
662666
shear_params=shear_params,
663667
translate_params=translate_params,
664668
scale_params=scale_params,
669+
affine=affine,
665670
spatial_size=spatial_size,
666671
device=device,
667672
)

tests/test_affine.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,17 @@
5656
p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),
5757
]
5858
)
59+
TESTS.append(
60+
[
61+
dict(
62+
affine=p(torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])),
63+
padding_mode="zeros",
64+
device=device,
65+
),
66+
{"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)},
67+
p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),
68+
]
69+
)
5970
TESTS.append(
6071
[
6172
dict(padding_mode="zeros", device=device),

tests/test_affine_grid.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,35 @@
5454
),
5555
]
5656
)
57+
TESTS.append(
58+
[
59+
{
60+
"affine": p(
61+
torch.tensor(
62+
[[-10.8060, -8.4147, 0.0000], [-16.8294, 5.4030, 0.0000], [0.0000, 0.0000, 1.0000]]
63+
)
64+
)
65+
},
66+
{"grid": p(torch.ones((3, 3, 3)))},
67+
p(
68+
torch.tensor(
69+
[
70+
[
71+
[-19.2208, -19.2208, -19.2208],
72+
[-19.2208, -19.2208, -19.2208],
73+
[-19.2208, -19.2208, -19.2208],
74+
],
75+
[
76+
[-11.4264, -11.4264, -11.4264],
77+
[-11.4264, -11.4264, -11.4264],
78+
[-11.4264, -11.4264, -11.4264],
79+
],
80+
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
81+
]
82+
)
83+
),
84+
]
85+
)
5786
TESTS.append(
5887
[
5988
{"rotate_params": (1.0, 1.0, 1.0), "scale_params": (-20, 10), "device": device},
@@ -99,6 +128,7 @@
99128
]
100129
)
101130

131+
102132
_rtol = 5e-2 if is_tf32_env() else 1e-4
103133

104134

tests/test_affined.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@
4949
p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),
5050
]
5151
)
52+
TESTS.append(
53+
[
54+
dict(
55+
keys="img",
56+
affine=p(torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])),
57+
padding_mode="zeros",
58+
spatial_size=(4, 4),
59+
device=device,
60+
),
61+
{"img": p(np.arange(4).reshape((1, 2, 2)))},
62+
p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),
63+
]
64+
)
5265
TESTS.append(
5366
[
5467
dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), device=device),

0 commit comments

Comments
 (0)