|
15 | 15 | roc, |
16 | 16 | multiclass_roc, |
17 | 17 | multiclass_precision_recall_curve, |
18 | | - dice_score |
| 18 | + dice_score, |
| 19 | + iou, |
19 | 20 | ) |
20 | 21 | from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric |
21 | 22 |
|
@@ -770,3 +771,48 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
770 | 771 | nan_score=self.nan_score, |
771 | 772 | no_fg_score=self.no_fg_score, |
772 | 773 | reduction=self.reduction) |
| 774 | + |
| 775 | + |
| 776 | +class IoU(TensorMetric): |
| 777 | + """ |
| 778 | + Computes the intersection over union. |
| 779 | +
|
| 780 | + Example: |
| 781 | +
|
| 782 | + >>> pred = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0], |
| 783 | + ... [0, 0, 1, 1, 1, 0, 0, 0], |
| 784 | + ... [0, 0, 0, 0, 0, 0, 0, 0]]) |
| 785 | + >>> target = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0], |
| 786 | + ... [0, 0, 0, 1, 1, 1, 0, 0], |
| 787 | + ... [0, 0, 0, 0, 0, 0, 0, 0]]) |
| 788 | + >>> metric = IoU() |
| 789 | + >>> metric(pred, target) |
| 790 | + tensor(0.7045) |
| 791 | +
|
| 792 | + """ |
| 793 | + def __init__(self, |
| 794 | + remove_bg: bool = False, |
| 795 | + reduction: str = 'elementwise_mean'): |
| 796 | + """ |
| 797 | + Args: |
| 798 | + remove_bg: Flag to state whether a background class has been included |
| 799 | + within input parameters. If true, will remove background class. If |
| 800 | + false, return IoU over all classes. |
| 801 | + Assumes that background is '0' class in input tensor |
| 802 | + reduction: a method for reducing IoU over labels (default: takes the mean) |
| 803 | + Available reduction methods: |
| 804 | +
|
| 805 | + - elementwise_mean: takes the mean |
| 806 | + - none: pass array |
| 807 | + - sum: add elements |
| 808 | + """ |
| 809 | + super().__init__(name='iou') |
| 810 | + self.remove_bg = remove_bg |
| 811 | + self.reduction = reduction |
| 812 | + |
| 813 | + def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, |
| 814 | + sample_weight: Optional[torch.Tensor] = None): |
| 815 | + """ |
| 816 | + Actual metric calculation. |
| 817 | + """ |
| 818 | + return iou(y_pred, y_true, remove_bg=self.remove_bg, reduction=self.reduction) |
0 commit comments