1+ # -*- coding: utf-8 -*-
2+ from __future__ import print_function , division
3+
4+ import torch
5+ import torch .nn as nn
6+ from pymic .loss .util import reshape_tensor_to_2D , get_classwise_dice
7+
8+ class DiceLoss (nn .Module ):
9+ def __init__ (self , params ):
10+ super (DiceLoss , self ).__init__ ()
11+ self .enable_pix_weight = params ['DiceLoss_Enable_Pixel_Weight' .lower ()]
12+ self .enable_cls_weight = params ['DiceLoss_Enable_Class_Weight' .lower ()]
13+
14+ def forward (self , loss_input_dict ):
15+ predict = loss_input_dict ['prediction' ]
16+ soft_y = loss_input_dict ['ground_truth' ]
17+ pix_w = loss_input_dict ['pixel_weight' ]
18+ cls_w = loss_input_dict ['class_weight' ]
19+ softmax = loss_input_dict ['softmax' ]
20+
21+ if (softmax ):
22+ predict = nn .Softmax (dim = 1 )(predict )
23+ predict = reshape_tensor_to_2D (predict )
24+ soft_y = reshape_tensor_to_2D (soft_y )
25+
26+ if (self .enable_pix_weight ):
27+ if (pix_w is None ):
28+ raise ValueError ("Pixel weight is enabled but not defined" )
29+ pix_w = reshape_tensor_to_2D (pix_w )
30+ dice_score = get_classwise_dice (predict , soft_y , pix_w )
31+ if (self .enable_cls_weight ):
32+ if (cls_w is None ):
33+ raise ValueError ("Class weight is enabled but not defined" )
34+ weighted_dice = dice_score * cls_w
35+ avg_dice = weighted_dice .sum () / cls_w .sum ()
36+ else :
37+ avg_dice = torch .mean (dice_score )
38+ dice_loss = 1.0 - avg_dice
39+ return dice_loss
40+
41+ class MultiScaleDiceLoss (nn .Module ):
42+ def __init__ (self , params ):
43+ super (MultiScaleDiceLoss , self ).__init__ ()
44+ self .enable_pix_weight = params ['MultiScaleDiceLoss_Enable_Pixel_Weight' .lower ()]
45+ self .enable_cls_weight = params ['MultiScaleDiceLoss_Enable_Class_Weight' .lower ()]
46+ self .multi_scale_weight = params ['MultiScaleDiceLoss_Scale_Weight' .lower ()]
47+ dice_params = {'DiceLoss_Enable_Pixel_Weight' .lower (): self .enable_pix_weight ,
48+ 'DiceLoss_Enable_Class_Weight' .lower (): self .enable_cls_weight }
49+ self .base_loss = DiceLoss (dice_params )
50+
51+ def forward (self , loss_input_dict ):
52+ predict = loss_input_dict ['prediction' ]
53+ soft_y = loss_input_dict ['ground_truth' ]
54+ pix_w = loss_input_dict ['pixel_weight' ]
55+ cls_w = loss_input_dict ['class_weight' ]
56+ softmax = loss_input_dict ['softmax' ]
57+ if (isinstance (predict , tuple ) or isinstance (predict , list )):
58+ predict_num = len (predict )
59+ assert (predict_num == len (self .multi_scale_weight ))
60+ loss = 0.0
61+ weight = 0.0
62+ interp_mode = 'trilinear' if (len (predict [0 ].shape ) == 5 ) else 'bilinear'
63+ for i in range (predict_num ):
64+ soft_y_temp = nn .functional .interpolate (soft_y ,
65+ size = list (predict [i ].shape )[2 :], mode = interp_mode )
66+ if (pix_w is not None ):
67+ pix_w_temp = nn .functional .interpolate (pix_w ,
68+ size = list (predict [i ].shape )[2 :], mode = interp_mode )
69+ else :
70+ pix_w_temp = None
71+ temp_loss_dict = {}
72+ temp_loss_dict ['prediction' ] = predict [i ]
73+ temp_loss_dict ['ground_truth' ] = soft_y_temp
74+ temp_loss_dict ['pixel_weight' ] = pix_w_temp
75+ temp_loss_dict ['class_weight' ] = cls_w
76+ temp_loss_dict ['softmax' ] = softmax
77+ temp_loss = self .base_loss (temp_loss_dict )
78+ loss = loss + temp_loss * self .multi_scale_weight [i ]
79+ weight = weight + self .multi_scale_weight [i ]
80+ loss = loss / weight
81+ else :
82+ loss = self .base_loss (loss_input_dict )
83+ return loss
0 commit comments