@@ -311,8 +311,21 @@ def aggregated_jaccard_index(
311311 return aji
312312
313313
314+ def _absent_inds (true : np .ndarray , pred : np .ndarray , num_classes : int ) -> np .ndarray :
315+ """Get the class indices that are not present in either `true` or `pred`."""
316+ t = np .unique (true )
317+ p = np .unique (pred )
318+ not_pres = np .setdiff1d (np .arange (num_classes ), np .union1d (t , p ))
319+
320+ return not_pres
321+
322+
314323def iou_multiclass (
315- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
324+ true : np .ndarray ,
325+ pred : np .ndarray ,
326+ num_classes : int ,
327+ eps : float = 1e-8 ,
328+ clamp_absent : bool = True ,
316329) -> np .ndarray :
317330 """Compute multi-class intersection over union for semantic segmentation masks.
318331
@@ -326,6 +339,9 @@ def iou_multiclass(
326339 Number of classes in the training dataset.
327340 eps : float, default=1e-8:
328341 Epsilon to avoid zero div errors.
342+ clamp_absent : bool, default=True
343+ If a class is not present in either true or pred, the value of that ix
344+ in the result array will be clamped to -1.0.
329345
330346 Returns
331347 -------
@@ -337,11 +353,21 @@ def iou_multiclass(
337353 fp = fp .diagonal ()
338354 fn = fn .diagonal ()
339355
340- return tp / (tp + fp + fn + eps )
356+ iou = tp / (tp + fp + fn + eps )
357+
358+ if clamp_absent :
359+ not_pres = _absent_inds (true , pred , num_classes )
360+ iou [not_pres ] = - 1.0
361+
362+ return iou
341363
342364
343365def accuracy_multiclass (
344- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
366+ true : np .ndarray ,
367+ pred : np .ndarray ,
368+ num_classes : int ,
369+ eps : float = 1e-8 ,
370+ clamp_absent : bool = True ,
345371) -> np .ndarray :
346372 """Compute multi-class accuracy for semantic segmentation masks.
347373
@@ -355,6 +381,9 @@ def accuracy_multiclass(
355381 Number of classes in the training dataset.
356382 eps : float, default=1e-8:
357383 Epsilon to avoid zero div errors.
384+ clamp_absent: bool = True
385+ If a class is not present in either true or pred, the value of that ix
386+ in the result array will be clamped to -1.0.
358387
359388 Returns
360389 -------
@@ -367,11 +396,21 @@ def accuracy_multiclass(
367396 fn = fn .diagonal ()
368397 tn = np .prod (true .shape ) - (tp + fn + fp )
369398
370- return (tp + tn ) / (tp + fp + fn + tn + eps )
399+ accuracy = (tp + tn ) / (tp + fp + fn + tn + eps )
400+
401+ if clamp_absent :
402+ not_pres = _absent_inds (true , pred , num_classes )
403+ accuracy [not_pres ] = - 1.0
404+
405+ return accuracy
371406
372407
373408def f1score_multiclass (
374- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
409+ true : np .ndarray ,
410+ pred : np .ndarray ,
411+ num_classes : int ,
412+ eps : float = 1e-8 ,
413+ clamp_absent : bool = True ,
375414) -> np .ndarray :
376415 """Compute multi-class f1-score for semantic segmentation masks.
377416
@@ -385,6 +424,9 @@ def f1score_multiclass(
385424 Number of classes in the training dataset.
386425 eps : float, default=1e-8:
387426 Epsilon to avoid zero div errors.
427+ clamp_absent: bool = True
428+ If a class is not present in either true or pred, the value of that ix
429+ in the result array will be clamped to -1.0.
388430
389431 Returns
390432 -------
@@ -396,11 +438,21 @@ def f1score_multiclass(
396438 fp = fp .diagonal ()
397439 fn = fn .diagonal ()
398440
399- return tp / (0.5 * fp + 0.5 * fn + tp + eps )
441+ f1 = tp / (0.5 * fp + 0.5 * fn + tp + eps )
442+
443+ if clamp_absent :
444+ not_pres = _absent_inds (true , pred , num_classes )
445+ f1 [not_pres ] = - 1.0
446+
447+ return f1
400448
401449
402450def dice_multiclass (
403- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
451+ true : np .ndarray ,
452+ pred : np .ndarray ,
453+ num_classes : int ,
454+ eps : float = 1e-8 ,
455+ clamp_absent : bool = True ,
404456) -> np .ndarray :
405457 """Compute multi-class dice for semantic segmentation masks.
406458
@@ -414,6 +466,9 @@ def dice_multiclass(
414466 Number of classes in the training dataset.
415467 eps : float, default=1e-8:
416468 Epsilon to avoid zero div errors.
469+ clamp_absent: bool = True
470+ If a class is not present in either true or pred, the value of that ix
471+ in the result array will be clamped to -1.0.
417472
418473 Returns
419474 -------
@@ -425,11 +480,21 @@ def dice_multiclass(
425480 fp = fp .diagonal ()
426481 fn = fn .diagonal ()
427482
428- return 2 * tp / (2 * tp + fp + fn + eps )
483+ dice = 2 * tp / (2 * tp + fp + fn + eps )
484+
485+ if clamp_absent :
486+ not_pres = _absent_inds (true , pred , num_classes )
487+ dice [not_pres ] = - 1.0
488+
489+ return dice
429490
430491
431492def sensitivity_multiclass (
432- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
493+ true : np .ndarray ,
494+ pred : np .ndarray ,
495+ num_classes : int ,
496+ eps : float = 1e-8 ,
497+ clamp_absent : bool = True ,
433498) -> np .ndarray :
434499 """Compute multi-class sensitivity for semantic segmentation masks.
435500
@@ -443,6 +508,9 @@ def sensitivity_multiclass(
443508 Number of classes in the training dataset.
444509 eps : float, default=1e-8:
445510 Epsilon to avoid zero div errors.
511+ clamp_absent: bool = True
512+ If a class is not present in either true or pred, the value of that ix
513+ in the result array will be clamped to -1.0.
446514
447515 Returns
448516 -------
@@ -454,11 +522,21 @@ def sensitivity_multiclass(
454522 fp = fp .diagonal ()
455523 fn = fn .diagonal ()
456524
457- return tp / (tp + fn + eps )
525+ sensitivity = tp / (tp + fn + eps )
526+
527+ if clamp_absent :
528+ not_pres = _absent_inds (true , pred , num_classes )
529+ sensitivity [not_pres ] = - 1.0
530+
531+ return sensitivity
458532
459533
460534def specificity_multiclass (
461- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
535+ true : np .ndarray ,
536+ pred : np .ndarray ,
537+ num_classes : int ,
538+ eps : float = 1e-8 ,
539+ clamp_absent : bool = True ,
462540) -> np .ndarray :
463541 """Compute multi-class specificity for semantic segmentation masks.
464542
@@ -472,6 +550,9 @@ def specificity_multiclass(
472550 Number of classes in the training dataset.
473551 eps : float, default=1e-8:
474552 Epsilon to avoid zero div errors.
553+ clamp_absent: bool = True
554+ If a class is not present in either true or pred, the value of that ix
555+ in the result array will be clamped to -1.0.
475556
476557 Returns
477558 -------
@@ -483,4 +564,10 @@ def specificity_multiclass(
483564 fp = fp .diagonal ()
484565 fn = fn .diagonal ()
485566
486- return tp / (tp + fp + eps )
567+ specificity = tp / (tp + fp + eps )
568+
569+ if clamp_absent :
570+ not_pres = _absent_inds (true , pred , num_classes )
571+ specificity [not_pres ] = - 1.0
572+
573+ return specificity
0 commit comments