99
1010
def safe_divide(numerator: Tensor, denominator: Tensor, epsilon: float = 1e-8) -> Tensor:
    """Divide ``numerator`` by ``denominator`` with the denominator floored at ``epsilon``.

    Every denominator element smaller than ``epsilon`` (zeros and negative
    values included) is replaced by ``epsilon`` before the division, so the
    quotient never divides by zero.

    Args:
        numerator: Dividend tensor.
        denominator: Divisor tensor; values below ``epsilon`` are raised to it.
        epsilon: Smallest divisor value allowed.

    Returns:
        The element-wise quotient ``numerator / max(denominator, epsilon)``.
    """
    # A one-sided floor is all that is needed here. The earlier form clipped to
    # [epsilon, reduce_max(denominator)], which hands clip_by_value a max below
    # its min whenever every denominator is under epsilon (e.g. all zeros) --
    # a case its documentation does not guarantee -- and paid for a full
    # reduction that never changed any element.
    floored_denominator = tf.maximum(denominator, epsilon)
    return tf.math.divide(numerator, floored_denominator)
1615
1716
def safe_pow(x: Tensor, p: Tensor, epsilon: float = 1e-8) -> Tensor:
    """Raise ``x`` to the power ``p`` after clamping ``x`` to ``[epsilon, 1 - epsilon]``.

    Args:
        x: Base tensor; clamped element-wise before exponentiation.
        p: Exponent passed to ``tf.pow``.
        epsilon: Margin kept away from both 0 and 1 when clamping ``x``.

    Returns:
        ``tf.pow`` of the clamped base and ``p``.
    """
    upper_bound = 1.0 - epsilon
    clamped_base = tf.clip_by_value(x, epsilon, upper_bound)
    return tf.pow(clamped_base, p)
2119
2220
23- class TcgeScalar (Loss ):
def reliability_penalizer(
    lms: Tensor, lambdas: Tensor, a: float, b: float, c: float
) -> Tensor:
    """Compute the penalizer term used for reliability regularization.

    With ``x = lambdas - lms`` the result is
    ``c * max(x * exp((x - 1) / b) / (1 - a), 0)`` evaluated element-wise.

    Args:
        lms: Tensor subtracted from ``lambdas``.
        lambdas: Tensor the penalty is computed for.
        a: Shape parameter; the penalty is scaled by ``1 / (1 - a)``.
        b: Shape parameter dividing the exponent ``(x - 1)``.
        c: Overall multiplier applied to the non-negative penalty.

    Returns:
        The scaled, non-negative penalty tensor.
    """
    gap = lambdas - lms
    # Same left-to-right evaluation as the one-line form: scale, then gap, then exp.
    unclamped_penalty = 1 / (1 - a) * gap * tf.exp((gap - 1) / b)
    non_negative_penalty = tf.maximum(unclamped_penalty, 0)
    return c * non_negative_penalty
26+
27+
28+ class TgceScalar (Loss ):
2429 """
2530 Truncated generalized cross entropy
2631 for semantic segmentation loss.
@@ -35,32 +40,40 @@ def __init__( # pylint: disable=too-many-arguments
3540 noise_tolerance : float = 0.1 ,
3641 a : float = 0.7 ,
3742 b : float = 0.7 ,
43+ c : float = 1.0 ,
44+ lambda_reg_weight : float = 0.1 ,
45+ lambda_entropy_weight : float = 0.1 ,
46+ lambda_sum_weight : float = 0.1 ,
3847 epsilon : float = 1e-8 ,
3948 ) -> None :
4049 self .q = q
4150 self .num_classes = num_classes
4251 self .noise_tolerance = noise_tolerance
4352 self .a = a
4453 self .b = b
54+ self .c = c
55+ self .lambda_reg_weight = lambda_reg_weight
56+ self .lambda_entropy_weight = lambda_entropy_weight
57+ self .lambda_sum_weight = lambda_sum_weight
4558 self .epsilon = epsilon
4659 super ().__init__ (name = name )
4760
48- def penalizer (self , lms : tf .Tensor , lambdas : tf .Tensor ) -> tf .Tensor :
49- """Compute the penalizer term for reliability regularization."""
50- x = lambdas - lms
51- return tf .maximum (1 / (1 - self .a ) * x * tf .exp ((x - 1 ) / self .b ), 0 )
52-
5361 def call (
5462 self ,
5563 y_true : tf .Tensor ,
5664 y_pred : tf .Tensor ,
5765 lambda_r : tf .Tensor ,
5866 labeler_mask : tf .Tensor ,
5967 ) -> tf .Tensor :
68+ # Cast inputs to target data type
69+ y_true = tf .cast (y_true , TARGET_DATA_TYPE )
70+ y_pred = tf .cast (y_pred , TARGET_DATA_TYPE )
71+ lambda_r = tf .cast (lambda_r , TARGET_DATA_TYPE )
72+
6073 y_pred = tf .clip_by_value (y_pred , self .epsilon , 1.0 - self .epsilon )
6174 lambda_r = tf .clip_by_value (lambda_r , self .epsilon , 1.0 - self .epsilon )
6275
63- reg_term = self . penalizer (labeler_mask , lambda_r )
76+ reg_term = reliability_penalizer (labeler_mask , lambda_r , self . a , self . b , self . c )
6477
6578 y_pred_exp = tf .expand_dims (y_pred , axis = - 1 )
6679 y_pred_exp = tf .tile (y_pred_exp , [1 , 1 , 1 , 1 , tf .shape (y_true )[- 1 ]])
@@ -78,7 +91,28 @@ def call(
7891 (1.0 - tf .pow (self .noise_tolerance , self .q )) / (self .q + self .epsilon )
7992 )
8093
81- total_loss = tf .reduce_mean (term1 + term2 ) + reg_term
94+ # Only compute regularization terms for valid labelers
95+ valid_lambda_r = lambda_r * tf .expand_dims (tf .expand_dims (labeler_mask , 1 ), 1 )
96+ lambda_reg = self .lambda_reg_weight * tf .reduce_mean (
97+ tf .square (valid_lambda_r - 0.5 )
98+ )
99+
100+ lambda_entropy = - self .lambda_entropy_weight * tf .reduce_mean (
101+ valid_lambda_r * tf .math .log1p (valid_lambda_r )
102+ + (1 - valid_lambda_r ) * tf .math .log1p (1 - valid_lambda_r )
103+ )
104+
105+ lambda_sum = self .lambda_sum_weight * tf .reduce_mean (
106+ tf .square (tf .reduce_sum (valid_lambda_r , axis = - 1 ) - 1.0 )
107+ )
108+
109+ total_loss = (
110+ tf .reduce_mean (term1 + term2 )
111+ + reg_term
112+ + lambda_reg
113+ + lambda_entropy
114+ + lambda_sum
115+ )
82116
83117 total_loss = tf .where (
84118 tf .math .is_nan (total_loss ),
@@ -99,11 +133,14 @@ def get_config(
99133 ** base_config ,
100134 "q" : self .q ,
101135 "b" : self .b ,
136+ "lambda_reg_weight" : self .lambda_reg_weight ,
137+ "lambda_entropy_weight" : self .lambda_entropy_weight ,
138+ "lambda_sum_weight" : self .lambda_sum_weight ,
102139 "epsilon" : self .epsilon ,
103140 }
104141
105142
106- class TcgeFeatures (Loss ):
143+ class TgceFeatures (Loss ):
107144 """
108145 Truncated generalized cross entropy for semantic segmentation loss
109146 with feature-based reliability (reliability map from bottleneck features).
@@ -210,7 +247,7 @@ def get_config(
210247 }
211248
212249
213- class TcgePixel (Loss ):
250+ class TgcePixel (Loss ):
214251 """
215252 Truncated generalized cross entropy for semantic segmentation loss
216253 with pixel-wise reliability (full resolution reliability map).
0 commit comments