@@ -235,38 +235,7 @@ def equals(A, B, reraise=False):
235235 for field in sorted (field_set ):
236236 x1 , x2 = getattr (A , field ), getattr (B , field )
237237 err_msg = f"{ field } on { A } did not equal { field } on { B } (\n { x1 } vs\n { x2 } \n )"
238-
239- if isinstance (x1 , pd .DataFrame ):
240- try :
241- assert_frame_equal (x1 , x2 , check_like = True )
242- except :
243- print (err_msg )
244- raise
245- elif isinstance (x1 , np .ndarray ):
246- np .testing .assert_array_almost_equal (x1 , x2 , err_msg = err_msg )
247- elif isinstance (x1 , xr .DataArray ):
248- xr .testing .assert_allclose (x1 , x2 )
249- elif isinstance (x1 , (list ,)):
250- assert x1 == x2 , err_msg
251- elif isinstance (x1 , (sitk .Image ,)):
252- assert x1 .GetSize () == x2 .GetSize (), err_msg
253- assert x1 == x2 , err_msg
254- elif isinstance (x1 , (dict ,)):
255- for key in set (x1 .keys ()).union (set (x2 .keys ())):
256- key_err_msg = f"{ key } on { field } on { A } did not equal { key } on { field } on { B } "
257-
258- if isinstance (x1 [key ], (np .ndarray ,)):
259- np .testing .assert_array_almost_equal (x1 [key ], x2 [key ], err_msg = key_err_msg )
260- elif isinstance (x1 [key ], (float ,)):
261- if math .isnan (x1 [key ]) or math .isnan (x2 [key ]):
262- assert math .isnan (x1 [key ]) and math .isnan (x2 [key ]), key_err_msg
263- else :
264- assert x1 [key ] == x2 [key ], key_err_msg
265- else :
266- assert x1 [key ] == x2 [key ], key_err_msg
267-
268- else :
269- assert x1 == x2 , err_msg
238+ compare_fields (x1 , x2 , err_msg )
270239
271240 except NotImplementedError as e :
272241 A_implements_get_field = hasattr (A .api , getattr (type (A ), field ).getter_name )
@@ -279,3 +248,38 @@ def equals(A, B, reraise=False):
279248 return False
280249
281250 return True
251+
252+
253+
254+ def compare_fields (x1 , x2 , err_msg = "" ):
255+ if isinstance (x1 , pd .DataFrame ):
256+ try :
257+ assert_frame_equal (x1 , x2 , check_like = True )
258+ except :
259+ print (err_msg )
260+ raise
261+ elif isinstance (x1 , np .ndarray ):
262+ np .testing .assert_array_almost_equal (x1 , x2 , err_msg = err_msg )
263+ elif isinstance (x1 , xr .DataArray ):
264+ xr .testing .assert_allclose (x1 , x2 )
265+ elif isinstance (x1 , (list ,)):
266+ assert x1 == x2 , err_msg
267+ elif isinstance (x1 , (sitk .Image ,)):
268+ assert x1 .GetSize () == x2 .GetSize (), err_msg
269+ assert x1 == x2 , err_msg
270+ elif isinstance (x1 , (dict ,)):
271+ for key in set (x1 .keys ()).union (set (x2 .keys ())):
272+ key_err_msg = f"mismatch when checking key { key } . { err_msg } "
273+
274+ if isinstance (x1 [key ], (np .ndarray ,)):
275+ np .testing .assert_array_almost_equal (x1 [key ], x2 [key ], err_msg = key_err_msg )
276+ elif isinstance (x1 [key ], (float ,)):
277+ if math .isnan (x1 [key ]) or math .isnan (x2 [key ]):
278+ assert math .isnan (x1 [key ]) and math .isnan (x2 [key ]), key_err_msg
279+ else :
280+ assert x1 [key ] == x2 [key ], key_err_msg
281+ else :
282+ assert x1 [key ] == x2 [key ], key_err_msg
283+
284+ else :
285+ assert x1 == x2 , err_msg
0 commit comments