@@ -132,7 +132,11 @@ def computeQualityMeasures(lP: np.ndarray,
132132
133133 if set (distance_metrics ).intersection (metrics_names ) or not metrics_names :
134134 # Surface distance measures
135- signed_distance_map = sitk .SignedMaurerDistanceMap (labelTrue > 0.5 , squaredDistance = False ,
135+ if np .sum (lT ) == 0 : # all 0, set the distance map to 0. Otherwise, SignedMaurerDistanceMap may raise exceptions.
136+ signed_distance_map = sitk .GetImageFromArray (lT , isVector = False )
137+ signed_distance_map .SetSpacing (spacing )
138+ else :
139+ signed_distance_map = sitk .SignedMaurerDistanceMap (labelTrue > 0.5 , squaredDistance = False ,
136140 useImageSpacing = True ) # It need to be adapted.
137141
138142 ref_distance_map = sitk .Abs (signed_distance_map )
@@ -145,17 +149,22 @@ def computeQualityMeasures(lP: np.ndarray,
145149
146150 num_ref_surface_pixels = int (statistics_image_filter .GetSum ())
147151
148- signed_distance_map_pred = sitk .SignedMaurerDistanceMap (labelPred > 0.5 , squaredDistance = False ,
149- useImageSpacing = True )
152+
153+ if np .sum (lP ) == 0 : # all 0, set the distance map to 0. Otherwise, SignedMaurerDistanceMap may raise exceptions.
154+ signed_distance_map_pred = sitk .GetImageFromArray (lP , isVector = False )
155+ signed_distance_map_pred .SetSpacing (spacing )
156+ else :
157+ signed_distance_map_pred = sitk .SignedMaurerDistanceMap (labelPred > 0.5 , squaredDistance = False ,
158+ useImageSpacing = True )
150159
151160 seg_distance_map = sitk .Abs (signed_distance_map_pred )
152161
153162 seg_surface = sitk .LabelContour (labelPred > 0.5 , fullyConnected = fullyConnected )
154163 seg_surface_array = sitk .GetArrayViewFromImage (seg_surface )
155164
156- seg2ref_distance_map = ref_distance_map * sitk .Cast (seg_surface , sitk .sitkFloat32 )
165+ seg2ref_distance_map = sitk . Cast ( ref_distance_map , sitk . sitkFloat32 ) * sitk .Cast (seg_surface , sitk .sitkFloat32 )
157166
158- ref2seg_distance_map = seg_distance_map * sitk .Cast (ref_surface , sitk .sitkFloat32 )
167+ ref2seg_distance_map = sitk . Cast ( seg_distance_map , sitk . sitkFloat32 ) * sitk .Cast (ref_surface , sitk .sitkFloat32 )
159168
160169 statistics_image_filter .Execute (seg_surface > 0.5 )
161170
@@ -414,12 +423,20 @@ def main():
414423 gdth_path = 'data/gdth'
415424 pred_path = 'data/pred'
416425 csv_file = 'metrics.csv'
417-
418- write_metrics (labels = labels [1 :], # exclude background
419- gdth_path = gdth_path ,
420- pred_path = pred_path ,
421- csv_file = csv_file )
422-
426+
427+ labels = [0 , 1 , 2 , 3 , 4 ]
428+ gdth_img = np .array ([[0 ,0 ,0 ],
429+ [0 ,2 ,1 ]])
430+ pred_img = np .array ([[0 ,0 ,0 ],
431+ [0 ,3 ,1 ]])
432+ csv_file = 'metrics.csv'
433+ spacing = [1 , 1 ]
434+ metrics = write_metrics (labels = labels [1 :], # exclude background if needed
435+ gdth_img = gdth_img ,
436+ pred_img = pred_img ,
437+ csv_file = csv_file ,
438+ spacing = spacing )
439+ print (metrics )
423440
424441if __name__ == "__main__" :
425442 main ()
0 commit comments