Skip to content

Commit 7a98e16

Browse files
committed
TEST: klass sometimes missing, equality sometimes undefined
1 parent 6acf044 commit 7a98e16

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

nibabel/tests/test_image_api.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -464,36 +464,55 @@ def validate_to_bytes(self, imaker, params):
464464

465465
def validate_from_bytes(self, imaker, params):
466466
img = imaker()
467+
klass = getattr(self, 'klass', img.__class__)
467468
with InTemporaryDirectory():
468469
fname = 'img' + self.standard_extension
469470
img.to_filename(fname)
470471

471472
all_images = list(getattr(self, 'example_images', [])) + [{'fname': fname}]
472473
for img_params in all_images:
473-
img_a = self.klass.from_filename(img_params['fname'])
474+
img_a = klass.from_filename(img_params['fname'])
474475
with open(img_params['fname'], 'rb') as fobj:
475-
img_b = self.klass.from_bytes(fobj.read())
476+
img_b = klass.from_bytes(fobj.read())
476477

477-
assert img_a.header == img_b.header
478+
assert self._header_eq(img_a.header, img_b.header)
478479
assert np.array_equal(img_a.get_data(), img_b.get_data())
479480

480481
def validate_to_from_bytes(self, imaker, params):
481482
img = imaker()
483+
klass = getattr(self, 'klass', img.__class__)
482484
with InTemporaryDirectory():
483485
fname = 'img' + self.standard_extension
484486
img.to_filename(fname)
485487

486488
all_images = list(getattr(self, 'example_images', [])) + [{'fname': fname}]
487489
for img_params in all_images:
488-
img_a = self.klass.from_filename(img_params['fname'])
490+
img_a = klass.from_filename(img_params['fname'])
489491
bytes_a = img_a.to_bytes()
490492

491-
img_b = self.klass.from_bytes(bytes_a)
493+
img_b = klass.from_bytes(bytes_a)
492494

493495
assert img_b.to_bytes() == bytes_a
494-
assert img_a.header == img_b.header
496+
assert self._header_eq(img_a.header, img_b.header)
495497
assert np.array_equal(img_a.get_data(), img_b.get_data())
496498

499+
@staticmethod
500+
def _header_eq(header_a, header_b):
501+
""" Quick-and-dirty header equality check
502+
503+
Abstract classes may have undefined equality, in which case test for
504+
same type
505+
"""
506+
not_implemented = False
507+
header_eq = True
508+
try:
509+
header_eq = header_a == header_b
510+
except NotImplementedError:
511+
header_eq = type(header_a) == type(header_b)
512+
513+
return header_eq
514+
515+
497516

498517
class LoadImageAPI(GenericImageAPI,
499518
DataInterfaceMixin,

0 commit comments

Comments
 (0)