|
9 | 9 | """Nonlinear transforms."""
|
10 | 10 | import warnings
|
11 | 11 | from functools import partial
|
| 12 | +from collections import namedtuple |
12 | 13 | import numpy as np
|
13 | 14 |
|
14 | 15 | from nitransforms import io
|
@@ -227,17 +228,54 @@ def __eq__(self, other):
|
227 | 228 | warnings.warn("Fields are equal, but references do not match.")
|
228 | 229 | return _eq
|
229 | 230 |
|
| 231 | + def to_x5(self, metadata=None): |
| 232 | + """Return an :class:`~nitransforms.io.x5.X5Transform` representation.""" |
| 233 | + from ._version import __version__ |
| 234 | + from .io.x5 import X5Domain, X5Transform |
| 235 | + |
| 236 | + metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {}) |
| 237 | + |
| 238 | + domain = None |
| 239 | + if (reference := self.reference) is not None: |
| 240 | + domain = X5Domain( |
| 241 | + grid=True, |
| 242 | + size=getattr(reference, "shape", (0, 0, 0)), |
| 243 | + mapping=reference.affine, |
| 244 | + coordinates="cartesian", |
| 245 | + ) |
| 246 | + |
| 247 | + kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) |
| 248 | + |
| 249 | + return X5Transform( |
| 250 | + type="nonlinear", |
| 251 | + subtype="densefield", |
| 252 | + representation="displacements", |
| 253 | + metadata=metadata, |
| 254 | + transform=self._deltas, |
| 255 | + dimension_kinds=kinds, |
| 256 | + domain=domain, |
| 257 | + ) |
| 258 | + |
230 | 259 | @classmethod
|
231 | 260 | def from_filename(cls, filename, fmt="X5"):
|
232 | 261 | _factory = {
|
233 | 262 | "afni": io.afni.AFNIDisplacementsField,
|
234 | 263 | "itk": io.itk.ITKDisplacementsField,
|
235 | 264 | "fsl": io.fsl.FSLDisplacementsField,
|
| 265 | + "X5": None, |
236 | 266 | }
|
237 |
| - if fmt not in _factory: |
| 267 | + fmt = fmt.upper() |
| 268 | + if fmt not in {k.upper() for k in _factory}: |
238 | 269 | raise NotImplementedError(f"Unsupported format <{fmt}>")
|
239 | 270 |
|
240 |
| - return cls(_factory[fmt].from_filename(filename)) |
| 271 | + if fmt == "X5": |
| 272 | + from .io.x5 import from_filename as load_x5 |
| 273 | + x5_xfm = load_x5(filename)[0] |
| 274 | + Domain = namedtuple("Domain", "affine shape") |
| 275 | + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) |
| 276 | + return cls(x5_xfm.transform, is_deltas=True, reference=reference) |
| 277 | + |
| 278 | + return cls(_factory[fmt.lower()].from_filename(filename)) |
241 | 279 |
|
242 | 280 |
|
243 | 281 | load = DenseFieldTransform.from_filename
|
@@ -293,6 +331,39 @@ def to_field(self, reference=None, dtype="float32"):
|
293 | 331 | field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
|
294 | 332 | )
|
295 | 333 |
|
| 334 | + def to_x5(self, metadata=None): |
| 335 | + """Return an :class:`~nitransforms.io.x5.X5Transform` representation.""" |
| 336 | + from ._version import __version__ |
| 337 | + from .io.x5 import X5Transform, X5Domain |
| 338 | + |
| 339 | + metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {}) |
| 340 | + |
| 341 | + domain = None |
| 342 | + if (reference := self.reference) is not None: |
| 343 | + domain = X5Domain( |
| 344 | + grid=True, |
| 345 | + size=getattr(reference, "shape", (0, 0, 0)), |
| 346 | + mapping=reference.affine, |
| 347 | + coordinates="cartesian", |
| 348 | + ) |
| 349 | + |
| 350 | + meta = metadata | { |
| 351 | + "KnotsAffine": self._knots.affine.tolist(), |
| 352 | + "KnotsShape": self._knots.shape, |
| 353 | + } |
| 354 | + |
| 355 | + kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) |
| 356 | + |
| 357 | + return X5Transform( |
| 358 | + type="nonlinear", |
| 359 | + subtype="bspline", |
| 360 | + representation="coefficients", |
| 361 | + metadata=meta, |
| 362 | + transform=self._coeffs, |
| 363 | + dimension_kinds=kinds, |
| 364 | + domain=domain, |
| 365 | + ) |
| 366 | + |
296 | 367 | def map(self, x, inverse=False):
|
297 | 368 | r"""
|
298 | 369 | Apply the transformation to a list of physical coordinate points.
|
|
0 commit comments