diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py index f9b8857..d3345ee 100644 --- a/nitransforms/tests/test_resampling.py +++ b/nitransforms/tests/test_resampling.py @@ -330,6 +330,30 @@ def test_apply_transformchain(tmp_path, testdata_path): assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL_LINEAR +@pytest.mark.xfail(reason="gh-281: applying a single 3D transform to 4D data") +def test_apply_single_3d_on_4d(): + """Apply one 3D transform across all timepoints of a 4D dataset.""" + nvols = 5 + data = np.zeros((10, 5, 5, nvols), dtype=np.float32) + for i in range(nvols): + data[i + 1, 2, 2, i] = i + 1 + + img = nb.Nifti1Image(data, np.eye(4)) + + mat = np.eye(4) + mat[0, 3] = -1.0 + ref = nb.Nifti1Image(np.zeros((10, 5, 5), dtype=np.uint8), np.eye(4)) + xfm = nitl.Affine(mat, reference=ref) + + moved = apply(xfm, img, order=0) + moved_data = np.asanyarray(moved.dataobj) + + assert moved_data.shape == data.shape + for i in range(nvols): + assert moved_data[i + 2, 2, 2, i] == i + 1 + assert moved_data[i + 1, 2, 2, i] == 0 + + @pytest.mark.parametrize("serialize_4d", [True, False]) def test_LinearTransformsMapping_apply( tmp_path, data_path, testdata_path, serialize_4d