Skip to content

Commit 7237eba

Browse files
committed
rf: Allow extensions to be constructed from objects without serialization
1 parent 8b0e699 commit 7237eba

File tree

1 file changed

+58
-19
lines changed

1 file changed

+58
-19
lines changed

nibabel/nifti1.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,25 @@
299299

300300

301301
class NiftiExtension(ty.Generic[T]):
302-
"""Base class for NIfTI header extensions."""
302+
"""Base class for NIfTI header extensions.
303+
304+
This class provides access to the extension content in various forms.
305+
For simple extensions that expose data as bytes, text or JSON, this class
306+
is sufficient. More complex extensions should be implemented as subclasses
307+
that provide custom serialization/deserialization methods.
308+
309+
Efficiency note:
310+
311+
This class assumes that the runtime representation of the extension content
312+
is mutable. Once a runtime representation is set, it is cached and will be
313+
serialized on any attempt to access the extension content as bytes, including
314+
determining the size of the extension in the NIfTI file.
315+
316+
If the runtime representation is never accessed, the raw bytes will be used
317+
without modification. While avoiding unnecessary deserialization, if there
318+
are bytestrings that do not produce a valid runtime representation, they will
319+
be written as-is, and may cause errors downstream.
320+
"""
303321

304322
code: int
305323
encoding: ty.Optional[str] = None
@@ -309,7 +327,8 @@ class NiftiExtension(ty.Generic[T]):
309327
def __init__(
310328
self,
311329
code: ty.Union[int, str],
312-
content: bytes,
330+
content: bytes = b'',
331+
object: ty.Optional[T] = None,
313332
) -> None:
314333
"""
315334
Parameters
@@ -318,21 +337,40 @@ def __init__(
318337
Canonical extension code as defined in the NIfTI standard, given
319338
either as integer or corresponding label
320339
(see :data:`~nibabel.nifti1.extension_codes`)
321-
content : bytes
322-
Extension content as read from the NIfTI file header. This content may
323-
be converted into a runtime representation.
340+
content : bytes, optional
341+
Extension content as read from the NIfTI file header.
342+
object : optional
343+
Extension content in runtime form.
324344
"""
325345
try:
326346
self.code = extension_codes.code[code] # type: ignore[assignment]
327347
except KeyError:
328348
self.code = code # type: ignore[assignment]
329349
self._content = content
350+
if object is not None:
351+
self._object = object
330352

331353
@classmethod
332354
def from_bytes(cls, content: bytes) -> Self:
355+
"""Create an extension from raw bytes.
356+
357+
This constructor may only be used in extension classes with a class
358+
attribute `code` to indicate the extension type.
359+
"""
333360
if not hasattr(cls, 'code'):
334361
raise NotImplementedError('from_bytes() requires a class attribute `code`')
335-
return cls(cls.code, content)
362+
return cls(cls.code, content=content)
363+
364+
@classmethod
365+
def from_object(cls, obj: T) -> Self:
366+
"""Create an extension from a runtime object.
367+
368+
This constructor may only be used in extension classes with a class
369+
attribute `code` to indicate the extension type.
370+
"""
371+
if not hasattr(cls, 'code'):
372+
raise NotImplementedError('from_object() requires a class attribute `code`')
373+
return cls(cls.code, object=obj)
336374

337375
# Handle (de)serialization of extension content
338376
# Subclasses may implement these methods to provide an alternative
@@ -401,7 +439,7 @@ def json(self) -> ty.Any:
401439
"""
402440
return json.loads(self.content)
403441

404-
def get_content(self) -> T:
442+
def get_object(self) -> T:
405443
"""Return the extension content in its runtime representation.
406444
407445
This method may return a different type for each extension type.
@@ -412,15 +450,14 @@ def get_content(self) -> T:
412450
self._object = self._unmangle(self._content)
413451
return self._object
414452

453+
# Backwards compatibility
454+
get_content = get_object
455+
415456
def get_sizeondisk(self) -> int:
416457
"""Return the size of the extension in the NIfTI file."""
417-
self._sync()
418-
# need raw value size plus 8 bytes for esize and ecode
419-
size = len(self._content) + 8
420-
# extensions size has to be a multiple of 16 bytes
421-
if size % 16 != 0:
422-
size += 16 - (size % 16)
423-
return size
458+
# need raw value size plus 8 bytes for esize and ecode, rounded up to next 16 bytes
459+
# Rounding C+8 up to M is done by (C+8 + (M-1)) // M * M
460+
return (len(self.content) + 23) // 16 * 16
424461

425462
def write_to(self, fileobj: ty.BinaryIO, byteswap: bool = False) -> None:
426463
"""Write header extensions to fileobj
@@ -438,20 +475,20 @@ def write_to(self, fileobj: ty.BinaryIO, byteswap: bool = False) -> None:
438475
-------
439476
None
440477
"""
441-
self._sync()
442478
extstart = fileobj.tell()
443-
rawsize = self.get_sizeondisk()
479+
rawsize = self.get_sizeondisk() # Calls _sync()
444480
# write esize and ecode first
445481
extinfo = np.array((rawsize, self.code), dtype=np.int32)
446482
if byteswap:
447483
extinfo = extinfo.byteswap()
448484
fileobj.write(extinfo.tobytes())
449-
# followed by the actual extension content
450-
# XXX if mangling upon load is implemented, it should be reverted here
485+
# followed by the actual extension content, synced above
451486
fileobj.write(self._content)
452487
# be nice and zero out remaining part of the extension till the
453488
# next 16 byte border
454-
fileobj.write(b'\x00' * (extstart + rawsize - fileobj.tell()))
489+
pad = extstart + rawsize - fileobj.tell()
490+
if pad:
491+
fileobj.write(bytes(pad))
455492

456493

457494
class Nifti1Extension(NiftiExtension[T]):
@@ -462,6 +499,8 @@ class Nifti1Extension(NiftiExtension[T]):
462499
dedicated subclasses.
463500
"""
464501

502+
code = 0 # Default to unknown extension
503+
465504
def _unmangle(self, value: bytes) -> T:
466505
"""Convert the extension content into its runtime representation.
467506

0 commit comments

Comments
 (0)