|
9 | 9 | """Nonlinear transforms."""
|
10 | 10 | import warnings
|
11 | 11 | from functools import partial
|
| 12 | +from collections import namedtuple |
12 | 13 | import numpy as np
|
13 | 14 |
|
14 | 15 | from nitransforms import io
|
@@ -229,17 +230,54 @@ def __eq__(self, other):
|
229 | 230 | warnings.warn("Fields are equal, but references do not match.")
|
230 | 231 | return _eq
|
231 | 232 |
|
| 233 | + def to_x5(self, metadata=None): |
| 234 | + """Return an :class:`~nitransforms.io.x5.X5Transform` representation.""" |
| 235 | + from ._version import __version__ |
| 236 | + from .io.x5 import X5Domain, X5Transform |
| 237 | + |
| 238 | + metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {}) |
| 239 | + |
| 240 | + domain = None |
| 241 | + if (reference := self.reference) is not None: |
| 242 | + domain = X5Domain( |
| 243 | + grid=True, |
| 244 | + size=getattr(reference, "shape", (0, 0, 0)), |
| 245 | + mapping=reference.affine, |
| 246 | + coordinates="cartesian", |
| 247 | + ) |
| 248 | + |
| 249 | + kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) |
| 250 | + |
| 251 | + return X5Transform( |
| 252 | + type="nonlinear", |
| 253 | + subtype="densefield", |
| 254 | + representation="displacements", |
| 255 | + metadata=metadata, |
| 256 | + transform=self._deltas, |
| 257 | + dimension_kinds=kinds, |
| 258 | + domain=domain, |
| 259 | + ) |
| 260 | + |
232 | 261 | @classmethod
|
233 | 262 | def from_filename(cls, filename, fmt="X5"):
|
234 | 263 | _factory = {
|
235 | 264 | "afni": io.afni.AFNIDisplacementsField,
|
236 | 265 | "itk": io.itk.ITKDisplacementsField,
|
237 | 266 | "fsl": io.fsl.FSLDisplacementsField,
|
| 267 | + "X5": None, |
238 | 268 | }
|
239 |
| - if fmt not in _factory: |
| 269 | + fmt = fmt.upper() |
| 270 | + if fmt not in {k.upper() for k in _factory}: |
240 | 271 | raise NotImplementedError(f"Unsupported format <{fmt}>")
|
241 | 272 |
|
242 |
| - return cls(_factory[fmt].from_filename(filename)) |
| 273 | + if fmt == "X5": |
| 274 | + from .io.x5 import from_filename as load_x5 |
| 275 | + x5_xfm = load_x5(filename)[0] |
| 276 | + Domain = namedtuple("Domain", "affine shape") |
| 277 | + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) |
| 278 | + return cls(x5_xfm.transform, is_deltas=True, reference=reference) |
| 279 | + |
| 280 | + return cls(_factory[fmt.lower()].from_filename(filename)) |
243 | 281 |
|
244 | 282 |
|
245 | 283 | load = DenseFieldTransform.from_filename
|
@@ -295,6 +333,39 @@ def to_field(self, reference=None, dtype="float32"):
|
295 | 333 | field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
|
296 | 334 | )
|
297 | 335 |
|
| 336 | + def to_x5(self, metadata=None): |
| 337 | + """Return an :class:`~nitransforms.io.x5.X5Transform` representation.""" |
| 338 | + from ._version import __version__ |
| 339 | + from .io.x5 import X5Transform, X5Domain |
| 340 | + |
| 341 | + metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {}) |
| 342 | + |
| 343 | + domain = None |
| 344 | + if (reference := self.reference) is not None: |
| 345 | + domain = X5Domain( |
| 346 | + grid=True, |
| 347 | + size=getattr(reference, "shape", (0, 0, 0)), |
| 348 | + mapping=reference.affine, |
| 349 | + coordinates="cartesian", |
| 350 | + ) |
| 351 | + |
| 352 | + meta = metadata | { |
| 353 | + "KnotsAffine": self._knots.affine.tolist(), |
| 354 | + "KnotsShape": self._knots.shape, |
| 355 | + } |
| 356 | + |
| 357 | + kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) |
| 358 | + |
| 359 | + return X5Transform( |
| 360 | + type="nonlinear", |
| 361 | + subtype="bspline", |
| 362 | + representation="coefficients", |
| 363 | + metadata=meta, |
| 364 | + transform=self._coeffs, |
| 365 | + dimension_kinds=kinds, |
| 366 | + domain=domain, |
| 367 | + ) |
| 368 | + |
298 | 369 | def map(self, x, inverse=False):
|
299 | 370 | r"""
|
300 | 371 | Apply the transformation to a list of physical coordinate points.
|
|
0 commit comments