Skip to content

Commit bb8b808

Browse files
committed
RF: Add generic NiftiExtension base class
Nifti1Extension is a non-ideal base class for NIfTI extensions because it assumes that it is safe to use a null transformation, and thus defaults to `bytes` objects. This makes it difficult to define its typing behavior in a way that allows subclasses to refine the type such that type-checkers understand it. This patch creates a generic `NiftiExtension` class that parameterizes the "runtime representation" type. Nifti1Extension subclasses it with another type parameter that defaults to `bytes`, allowing it to be subclassed in turn (preserving the Nifti1Extension -> Nifti1DicomExtension subclass relationship) while still emitting `bytes`. We could have simply made `Nifti1Extension` the base class, but the mangle/unmangle methods need some casts or ignore comments to type-check cleanly. This separation allows us to have a clean base class with the legacy hacks cordoned off into a subclass.
1 parent 7a502a3 commit bb8b808

File tree

2 files changed

+166
-104
lines changed

2 files changed

+166
-104
lines changed

nibabel/nifti1.py

Lines changed: 163 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313

1414
from __future__ import annotations
1515

16+
import typing as ty
1617
import warnings
1718
from io import BytesIO
1819

1920
import numpy as np
2021
import numpy.linalg as npl
22+
from typing_extensions import TypeVar # PY312
2123

2224
from . import analyze # module import
2325
from .arrayproxy import get_obj_dtype
@@ -31,7 +33,19 @@
3133
from .spm99analyze import SpmAnalyzeHeader
3234
from .volumeutils import Recoder, endian_codes, make_dt_codes
3335

34-
pdcm, have_dicom, _ = optional_package('pydicom')
36+
if ty.TYPE_CHECKING:
37+
import pydicom as pdcm
38+
39+
have_dicom = True
40+
DicomDataset = pdcm.Dataset
41+
else:
42+
pdcm, have_dicom, _ = optional_package('pydicom')
43+
if have_dicom:
44+
DicomDataset = pdcm.Dataset
45+
else:
46+
DicomDataset = ty.Any
47+
48+
T = TypeVar('T', default=bytes)
3549

3650
# nifti1 flat header definition for Analyze-like first 348 bytes
3751
# first number in comments indicates offset in file header in bytes
@@ -283,110 +297,103 @@
283297
)
284298

285299

286-
class Nifti1Extension:
287-
"""Baseclass for NIfTI1 header extensions.
300+
class NiftiExtension(ty.Generic[T]):
301+
"""Base class for NIfTI header extensions."""
288302

289-
This class is sufficient to handle very simple text-based extensions, such
290-
as `comment`. More sophisticated extensions should/will be supported by
291-
dedicated subclasses.
292-
"""
303+
code: int
304+
encoding: ty.Optional[str] = None
305+
_content: bytes
306+
_object: ty.Optional[T] = None
293307

294-
def __init__(self, code, content):
308+
def __init__(
309+
self,
310+
code: ty.Union[int, str],
311+
content: bytes,
312+
) -> None:
295313
"""
296314
Parameters
297315
----------
298316
code : int or str
299317
Canonical extension code as defined in the NIfTI standard, given
300318
either as integer or corresponding label
301319
(see :data:`~nibabel.nifti1.extension_codes`)
302-
content : str
303-
Extension content as read from the NIfTI file header. This content is
304-
converted into a runtime representation.
320+
content : bytes
321+
Extension content as read from the NIfTI file header. This content may
322+
be converted into a runtime representation.
305323
"""
306324
try:
307-
self._code = extension_codes.code[code]
325+
self.code = extension_codes.code[code] # type: ignore[assignment]
308326
except KeyError:
309-
# XXX or fail or at least complain?
310-
self._code = code
311-
self._content = self._unmangle(content)
327+
self.code = code # type: ignore[assignment]
328+
self._content = content
312329

313-
def _unmangle(self, value):
314-
"""Convert the extension content into its runtime representation.
330+
# Handle (de)serialization of extension content
331+
# Subclasses may implement these methods to provide an alternative
332+
# view of the extension content. If left unimplemented, the content
333+
# must be bytes and is not modified.
334+
def _mangle(self, obj: T) -> bytes:
335+
raise NotImplementedError
315336

316-
The default implementation does nothing at all.
337+
def _unmangle(self, content: bytes) -> T:
338+
raise NotImplementedError
317339

318-
Parameters
319-
----------
320-
value : str
321-
Extension content as read from file.
340+
def _sync(self) -> None:
341+
"""Synchronize content with object.
322342
323-
Returns
324-
-------
325-
The same object that was passed as `value`.
326-
327-
Notes
328-
-----
329-
Subclasses should reimplement this method to provide the desired
330-
unmangling procedure and may return any type of object.
343+
This permits the runtime representation to be modified in-place
344+
and updates the bytes representation accordingly.
331345
"""
332-
return value
333-
334-
def _mangle(self, value):
335-
"""Convert the extension content into NIfTI file header representation.
346+
if self._object is not None:
347+
self._content = self._mangle(self._object)
336348

337-
The default implementation does nothing at all.
338-
339-
Parameters
340-
----------
341-
value : str
342-
Extension content in runtime form.
349+
def __repr__(self) -> str:
350+
try:
351+
code = extension_codes.label[self.code]
352+
except KeyError:
353+
# deal with unknown codes
354+
code = self.code
355+
return f'{self.__class__.__name__}({code}, {self._content!r})'
343356

344-
Returns
345-
-------
346-
str
357+
def __eq__(self, other: object) -> bool:
358+
return (
359+
isinstance(other, self.__class__)
360+
and self.code == other.code
361+
and self.content == other.content
362+
)
347363

348-
Notes
349-
-----
350-
Subclasses should reimplement this method to provide the desired
351-
mangling procedure.
352-
"""
353-
return value
364+
def __ne__(self, other):
365+
return not self == other
354366

355367
def get_code(self):
356368
"""Return the canonical extension type code."""
357-
return self._code
369+
return self.code
358370

359-
def get_content(self):
360-
"""Return the extension content in its runtime representation."""
371+
@property
372+
def content(self) -> bytes:
373+
"""Return the extension content as raw bytes."""
374+
self._sync()
361375
return self._content
362376

363-
def get_sizeondisk(self):
377+
def get_content(self) -> T:
378+
"""Return the extension content in its runtime representation.
379+
380+
This method may return a different type for each extension type.
381+
"""
382+
if self._object is None:
383+
self._object = self._unmangle(self._content)
384+
return self._object
385+
386+
def get_sizeondisk(self) -> int:
364387
"""Return the size of the extension in the NIfTI file."""
388+
self._sync()
365389
# need raw value size plus 8 bytes for esize and ecode
366-
size = len(self._mangle(self._content))
367-
size += 8
390+
size = len(self._content) + 8
368391
# extensions size has to be a multiple of 16 bytes
369392
if size % 16 != 0:
370393
size += 16 - (size % 16)
371394
return size
372395

373-
def __repr__(self):
374-
try:
375-
code = extension_codes.label[self._code]
376-
except KeyError:
377-
# deal with unknown codes
378-
code = self._code
379-
380-
s = f"Nifti1Extension('{code}', '{self._content}')"
381-
return s
382-
383-
def __eq__(self, other):
384-
return (self._code, self._content) == (other._code, other._content)
385-
386-
def __ne__(self, other):
387-
return not self == other
388-
389-
def write_to(self, fileobj, byteswap):
396+
def write_to(self, fileobj: ty.BinaryIO, byteswap: bool = False) -> None:
390397
"""Write header extensions to fileobj
391398
392399
Write starts at fileobj current file position.
@@ -402,22 +409,74 @@ def write_to(self, fileobj, byteswap):
402409
-------
403410
None
404411
"""
412+
self._sync()
405413
extstart = fileobj.tell()
406414
rawsize = self.get_sizeondisk()
407415
# write esize and ecode first
408-
extinfo = np.array((rawsize, self._code), dtype=np.int32)
416+
extinfo = np.array((rawsize, self.code), dtype=np.int32)
409417
if byteswap:
410418
extinfo = extinfo.byteswap()
411419
fileobj.write(extinfo.tobytes())
412420
# followed by the actual extension content
413421
# XXX if mangling upon load is implemented, it should be reverted here
414-
fileobj.write(self._mangle(self._content))
422+
fileobj.write(self._content)
415423
# be nice and zero out remaining part of the extension till the
416424
# next 16 byte border
417425
fileobj.write(b'\x00' * (extstart + rawsize - fileobj.tell()))
418426

419427

420-
class Nifti1DicomExtension(Nifti1Extension):
428+
class Nifti1Extension(NiftiExtension[T]):
429+
"""Baseclass for NIfTI1 header extensions.
430+
431+
This class is sufficient to handle very simple text-based extensions, such
432+
as `comment`. More sophisticated extensions should/will be supported by
433+
dedicated subclasses.
434+
"""
435+
436+
def _unmangle(self, value: bytes) -> T:
437+
"""Convert the extension content into its runtime representation.
438+
439+
The default implementation does nothing at all.
440+
441+
Parameters
442+
----------
443+
value : str
444+
Extension content as read from file.
445+
446+
Returns
447+
-------
448+
The same object that was passed as `value`.
449+
450+
Notes
451+
-----
452+
Subclasses should reimplement this method to provide the desired
453+
unmangling procedure and may return any type of object.
454+
"""
455+
return value # type: ignore[return-value]
456+
457+
def _mangle(self, value: T) -> bytes:
458+
"""Convert the extension content into NIfTI file header representation.
459+
460+
The default implementation does nothing at all.
461+
462+
Parameters
463+
----------
464+
value : str
465+
Extension content in runtime form.
466+
467+
Returns
468+
-------
469+
str
470+
471+
Notes
472+
-----
473+
Subclasses should reimplement this method to provide the desired
474+
mangling procedure.
475+
"""
476+
return value # type: ignore[return-value]
477+
478+
479+
class Nifti1DicomExtension(Nifti1Extension[DicomDataset]):
421480
"""NIfTI1 DICOM header extension
422481
423482
This class is a thin wrapper around pydicom to read a binary DICOM
@@ -427,7 +486,12 @@ class Nifti1DicomExtension(Nifti1Extension):
427486
header.
428487
"""
429488

430-
def __init__(self, code, content, parent_hdr=None):
489+
def __init__(
490+
self,
491+
code: ty.Union[int, str],
492+
content: ty.Union[bytes, DicomDataset, None] = None,
493+
parent_hdr: ty.Optional[Nifti1Header] = None,
494+
) -> None:
431495
"""
432496
Parameters
433497
----------
@@ -452,50 +516,48 @@ def __init__(self, code, content, parent_hdr=None):
452516
code should always be 2 for DICOM.
453517
"""
454518

455-
self._code = code
456-
if parent_hdr:
457-
self._is_little_endian = parent_hdr.endianness == '<'
458-
else:
459-
self._is_little_endian = True
519+
self._is_little_endian = parent_hdr is None or parent_hdr.endianness == '<'
520+
521+
bytes_content: bytes
460522
if isinstance(content, pdcm.dataset.Dataset):
461523
self._is_implicit_VR = False
462-
self._raw_content = self._mangle(content)
463-
self._content = content
524+
self._object = content
525+
bytes_content = self._mangle(content)
464526
elif isinstance(content, bytes): # Got a byte string - unmangle it
465-
self._raw_content = content
466-
self._is_implicit_VR = self._guess_implicit_VR()
467-
ds = self._unmangle(content, self._is_implicit_VR, self._is_little_endian)
468-
self._content = ds
527+
self._is_implicit_VR = self._guess_implicit_VR(content)
528+
self._object = self._unmangle(content)
529+
bytes_content = content
469530
elif content is None: # initialize a new dicom dataset
470531
self._is_implicit_VR = False
471-
self._content = pdcm.dataset.Dataset()
532+
self._object = pdcm.dataset.Dataset()
533+
bytes_content = self._mangle(self._object)
472534
else:
473535
raise TypeError(
474536
f'content must be either a bytestring or a pydicom Dataset. '
475537
f'Got {content.__class__}'
476538
)
539+
super().__init__(code, bytes_content)
477540

478-
def _guess_implicit_VR(self):
541+
@staticmethod
542+
def _guess_implicit_VR(content) -> bool:
479543
"""Try to guess DICOM syntax by checking for valid VRs.
480544
481545
Without a DICOM Transfer Syntax, it's difficult to tell if Value
482546
Representations (VRs) are included in the DICOM encoding or not.
483547
This reads where the first VR would be and checks it against a list of
484548
valid VRs
485549
"""
486-
potential_vr = self._raw_content[4:6].decode()
487-
if potential_vr in pdcm.values.converters.keys():
488-
implicit_VR = False
489-
else:
490-
implicit_VR = True
491-
return implicit_VR
492-
493-
def _unmangle(self, value, is_implicit_VR=False, is_little_endian=True):
494-
bio = BytesIO(value)
495-
ds = pdcm.filereader.read_dataset(bio, is_implicit_VR, is_little_endian)
496-
return ds
550+
potential_vr = content[4:6].decode()
551+
return potential_vr not in pdcm.values.converters.keys()
552+
553+
def _unmangle(self, obj: bytes) -> DicomDataset:
554+
return pdcm.filereader.read_dataset(
555+
BytesIO(obj),
556+
self._is_implicit_VR,
557+
self._is_little_endian,
558+
)
497559

498-
def _mangle(self, dataset):
560+
def _mangle(self, dataset: DicomDataset) -> bytes:
499561
bio = BytesIO()
500562
dio = pdcm.filebase.DicomFileLike(bio)
501563
dio.is_implicit_VR = self._is_implicit_VR

0 commit comments

Comments
 (0)