Skip to content

Commit 08e19b4

Browse files
committed
TEST: klass sometimes missing, equality sometimes undefined
1 parent a537b9f commit 08e19b4

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

nibabel/tests/test_image_api.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -506,36 +506,55 @@ def validate_to_bytes(self, imaker, params):
506506

507507
def validate_from_bytes(self, imaker, params):
508508
img = imaker()
509+
klass = getattr(self, 'klass', img.__class__)
509510
with InTemporaryDirectory():
510511
fname = 'img' + self.standard_extension
511512
img.to_filename(fname)
512513

513514
all_images = list(getattr(self, 'example_images', [])) + [{'fname': fname}]
514515
for img_params in all_images:
515-
img_a = self.klass.from_filename(img_params['fname'])
516+
img_a = klass.from_filename(img_params['fname'])
516517
with open(img_params['fname'], 'rb') as fobj:
517-
img_b = self.klass.from_bytes(fobj.read())
518+
img_b = klass.from_bytes(fobj.read())
518519

519-
assert img_a.header == img_b.header
520+
assert self._header_eq(img_a.header, img_b.header)
520521
assert np.array_equal(img_a.get_data(), img_b.get_data())
521522

522523
def validate_to_from_bytes(self, imaker, params):
523524
img = imaker()
525+
klass = getattr(self, 'klass', img.__class__)
524526
with InTemporaryDirectory():
525527
fname = 'img' + self.standard_extension
526528
img.to_filename(fname)
527529

528530
all_images = list(getattr(self, 'example_images', [])) + [{'fname': fname}]
529531
for img_params in all_images:
530-
img_a = self.klass.from_filename(img_params['fname'])
532+
img_a = klass.from_filename(img_params['fname'])
531533
bytes_a = img_a.to_bytes()
532534

533-
img_b = self.klass.from_bytes(bytes_a)
535+
img_b = klass.from_bytes(bytes_a)
534536

535537
assert img_b.to_bytes() == bytes_a
536-
assert img_a.header == img_b.header
538+
assert self._header_eq(img_a.header, img_b.header)
537539
assert np.array_equal(img_a.get_data(), img_b.get_data())
538540

541+
@staticmethod
542+
def _header_eq(header_a, header_b):
543+
""" Quick-and-dirty header equality check
544+
545+
Abstract classes may have undefined equality, in which case test for
546+
same type
547+
"""
548+
not_implemented = False
549+
header_eq = True
550+
try:
551+
header_eq = header_a == header_b
552+
except NotImplementedError:
553+
header_eq = type(header_a) == type(header_b)
554+
555+
return header_eq
556+
557+
539558

540559
class LoadImageAPI(GenericImageAPI,
541560
DataInterfaceMixin,

0 commit comments

Comments
 (0)