Skip to content

Commit 500a4b2

Browse files
committed
fix: shorter implementation of set transform
1 parent 58b8e00 commit 500a4b2

File tree

1 file changed

+4
-31
lines changed

1 file changed

+4
-31
lines changed

src/nifreeze/data/dmri.py

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -99,37 +99,14 @@ def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
9999
ValueError
100100
If ``gradients`` is None or doesn't match the data shape.
101101
"""
102-
# Basic validation
103-
if self.gradients is None:
104-
raise ValueError("Cannot set a transform on DWI data without a gradient table.")
105-
106-
# If the gradient table is Nx4, and dataobj has shape (..., N)
107-
n_volumes = self.dataobj.shape[-1]
108-
if self.gradients.shape[0] != n_volumes and self.gradients.shape[1] == n_volumes:
109-
# Possibly transposed gradient table - handle or raise an error
110-
raise ValueError("Gradient table shape does not match the data's last dimension.")
102+
if not Path(self._filepath).exists():
103+
self.to_filename(self._filepath)
111104

112105
reference = namedtuple("ImageGrid", ("shape", "affine"))(
113106
shape=self.dataobj.shape[:3], affine=self.affine
114107
)
115108
xform = Affine(matrix=affine, reference=reference)
116-
117-
if not Path(self._filepath).exists():
118-
self.to_filename(self._filepath)
119-
120-
# read original DWI data & b-vector
121-
with h5py.File(self._filepath, "r") as in_file:
122-
root = in_file["/0"]
123-
dwi_frame = np.asanyarray(root["dataobj"][..., index])
124-
bvec = np.asanyarray(root["gradients"][index, :3])
125-
126-
dwmoving = nb.Nifti1Image(dwi_frame, self.affine, None)
127-
128-
# resample and update orientation at index
129-
self.dataobj[..., index] = np.asanyarray(
130-
xform.apply(dwmoving, order=order).dataobj,
131-
dtype=self.dataobj.dtype,
132-
)
109+
bvec = self.gradients[index, :3]
133110

134111
# invert transform transform b-vector and origin
135112
r_bvec = (~xform).map([bvec, (0.0, 0.0, 0.0)])
@@ -138,11 +115,7 @@ def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
138115
# Normalize and update
139116
self.gradients[index, :3] = new_bvec / np.linalg.norm(new_bvec)
140117

141-
# update transform
142-
if self.em_affines is None:
143-
self.em_affines = np.zeros((self.dataobj.shape[-1], 4, 4))
144-
145-
self.em_affines[index] = xform.matrix
118+
super().set_transform(index, affine, order)
146119

147120
@classmethod
148121
def from_filename(cls, filename: Path | str) -> DWI:

0 commit comments

Comments
 (0)