Skip to content

Commit e1edc99

Browse files
authored
more tests about the transforms with MetaTensor (#4521)
* at least scale one Signed-off-by: Wenqi Li <[email protected]> * add more tests Signed-off-by: Wenqi Li <[email protected]> * dict-based tests Signed-off-by: Wenqi Li <[email protected]> * reviewed intensity transforms Signed-off-by: Wenqi Li <[email protected]> * review post Signed-off-by: Wenqi Li <[email protected]> * review utils Signed-off-by: Wenqi Li <[email protected]> * integration tests Signed-off-by: Wenqi Li <[email protected]> * adds pixdim property, fixes typos Signed-off-by: Wenqi Li <[email protected]>
1 parent 03eacf1 commit e1edc99

35 files changed

+312
-89
lines changed

monai/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from .thread_buffer import ThreadBuffer, ThreadDataLoader
7272
from .torchscript_utils import load_net_with_metadata, save_net_with_metadata
7373
from .utils import (
74+
affine_to_spacing,
7475
compute_importance_map,
7576
compute_shape_offset,
7677
convert_tables_to_dicts,

monai/data/meta_tensor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from monai.config.type_definitions import NdarrayTensor
2121
from monai.data.meta_obj import MetaObj, get_track_meta
22-
from monai.data.utils import decollate_batch, list_data_collate, remove_extra_metadata
22+
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
2323
from monai.utils.enums import PostFix
2424
from monai.utils.type_conversion import convert_to_tensor
2525

@@ -304,13 +304,18 @@ def as_dict(self, key: str) -> dict:
304304
@property
305305
def affine(self) -> torch.Tensor:
306306
"""Get the affine."""
307-
return self.meta["affine"] # type: ignore
307+
return self.meta.get("affine", self.get_default_affine()) # type: ignore
308308

309309
@affine.setter
310310
def affine(self, d: NdarrayTensor) -> None:
311311
"""Set the affine."""
312312
self.meta["affine"] = torch.as_tensor(d, device=self.device)
313313

314+
@property
315+
def pixdim(self):
316+
"""Get the spacing"""
317+
return affine_to_spacing(self.affine)
318+
314319
def new_empty(self, size, dtype=None, device=None, requires_grad=False):
315320
"""
316321
must be defined for deepcopy to work

monai/transforms/intensity/array.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
"HistogramNormalize",
7272
"IntensityRemap",
7373
"RandIntensityRemap",
74+
"ForegroundMask",
7475
]
7576

7677

monai/transforms/post/dictionary.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -645,17 +645,17 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
645645
trans_info=deepcopy(transform_info), mode="nearest", align_corners=None
646646
)
647647

648-
input = d[key]
649-
if isinstance(input, torch.Tensor):
650-
input = input.detach()
648+
inputs = d[key]
649+
if isinstance(inputs, torch.Tensor):
650+
inputs = inputs.detach()
651651

652-
if not isinstance(input, MetaTensor):
653-
input = MetaTensor(input)
654-
input.applied_operations = transform_info
655-
input.meta = meta_info
652+
if not isinstance(inputs, MetaTensor):
653+
inputs = MetaTensor(inputs)
654+
inputs.applied_operations = transform_info
655+
inputs.meta = meta_info
656656

657657
# construct the input dict data
658-
input_dict = {orig_key: input}
658+
input_dict = {orig_key: inputs}
659659

660660
with allow_missing_keys_mode(self.transform): # type: ignore
661661
inverted = self.transform.inverse(input_dict)

monai/transforms/spatial/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,10 @@ class ResampleToMatch(SpatialResample):
345345
and the size of the output image will match."""
346346

347347
@deprecated_arg(
348-
name="src_meta", since="0.8", msg_suffix="img should be `MetaTensor`, so affine can be extacted directly."
348+
name="src_meta", since="0.8", msg_suffix="img should be `MetaTensor`, so affine can be extracted directly."
349349
)
350350
@deprecated_arg(
351-
name="dst_meta", since="0.8", msg_suffix="img_dst should be `MetaTensor`, so affine can be extacted directly."
351+
name="dst_meta", since="0.8", msg_suffix="img_dst should be `MetaTensor`, so affine can be extracted directly."
352352
)
353353
def __call__( # type: ignore
354354
self,

monai/transforms/utility/dictionary.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,7 @@ class Lambdad(MapTransform, InvertibleTransform):
10111011
print(lambd(input_data)['label'].shape)
10121012
(4, 2, 2)
10131013
1014+
10141015
Args:
10151016
keys: keys of the corresponding items to be transformed.
10161017
See also: :py:class:`monai.transforms.compose.MapTransform`

monai/transforms/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def scale_affine(affine, spatial_size, new_spatial_size, centered: bool = True):
15891589
if spatial_size == new_spatial_size:
15901590
return affine
15911591
r = len(affine) - 1
1592-
s = np.array([float(o) / max(n, 0) for o, n in zip(spatial_size, new_spatial_size)])
1592+
s = np.array([float(o) / float(max(n, 1)) for o, n in zip(spatial_size, new_spatial_size)])
15931593
scale = create_scale(r, s.tolist())
15941594
if centered:
15951595
scale[:r, -1] = (np.diag(scale)[:r] - 1) / 2 # type: ignore

tests/test_activations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676

7777

7878
class TestActivations(unittest.TestCase):
79-
@parameterized.expand(TEST_CASES[:3])
79+
@parameterized.expand(TEST_CASES)
8080
def test_value_shape(self, input_param, img, out, expected_shape):
8181
result = Activations(**input_param)(img)
8282

tests/test_adjust_contrast.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class TestAdjustContrast(NumpyImageTestCase2D):
2929
def test_correct_results(self, gamma):
3030
adjuster = AdjustContrast(gamma=gamma)
3131
for p in TEST_NDARRAYS:
32-
result = adjuster(p(self.imt))
32+
im = p(self.imt)
33+
result = adjuster(im)
34+
self.assertTrue(type(im), type(result))
3335
if gamma == 1.0:
3436
expected = self.imt
3537
else:

tests/test_as_channel_last.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class TestAsChannelLast(unittest.TestCase):
2929
def test_shape(self, in_type, input_param, expected_shape):
3030
test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4]))
3131
result = AsChannelLast(**input_param)(test_data)
32+
self.assertEqual(type(result), type(test_data))
3233
self.assertTupleEqual(result.shape, expected_shape)
3334

3435

0 commit comments

Comments
 (0)