diff --git a/nitransforms/manip.py b/nitransforms/manip.py index b30fd646..9389197d 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -213,6 +213,8 @@ def _as_chain(x): """Convert a value into a transform chain.""" if isinstance(x, TransformChain): return x.transforms + if isinstance(x, TransformBase): + return [x] if isinstance(x, Iterable): return list(x) return [x] diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index 969b33ab..31627159 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -82,6 +82,20 @@ def test_loadsave_itk(tmp_path, data_path, testdata_path): ) +def test_mapping_chain(data_path): + xfm = nitl.load(data_path / "itktflist2.tfm", fmt="itk") + xfm = nitl.load(data_path / "itktflist2.tfm", fmt="itk") + assert len(xfm) == 9 + + # Addiition produces a chain + chain = xfm + xfm + # Length now means number of transforms, not number of affines in one transform + assert len(chain) == 2 + # Just because a LinearTransformsMapping is iterable does not mean we decompose it + chain += xfm + assert len(chain) == 3 + + @pytest.mark.parametrize( "image_orientation", [