@@ -40,7 +40,7 @@ def __init__( # pylint: disable=too-many-arguments
4040 b : float = 0.7 ,
4141 lambda_reg_weight : float = 0.1 ,
4242 lambda_entropy_weight : float = 0.1 ,
43- lambda_sum_weight : float = 0.1 ,
43+ lambda_sum_weight : float | None = None ,
4444 epsilon : float = 1e-8 ,
4545 ) -> None :
4646 self .q = q
@@ -98,8 +98,11 @@ def call(
9898 + (1 - valid_lambda_r ) * tf .math .log1p (1 - valid_lambda_r )
9999 )
100100
101- lambda_sum = self .lambda_sum_weight * tf .reduce_mean (
102- tf .square (tf .reduce_sum (valid_lambda_r , axis = - 1 ) - 1.0 )
101+ lambda_sum = (
102+ self .lambda_sum_weight
103+ * tf .reduce_mean (tf .square (tf .reduce_sum (valid_lambda_r , axis = - 1 ) - 1.0 ))
104+ if self .lambda_sum_weight is not None
105+ else 0.0
103106 )
104107
105108 total_loss = (
@@ -153,7 +156,7 @@ def __init__( # pylint: disable=too-many-arguments
153156 b : float = 0.7 ,
154157 lambda_reg_weight : float = 0.1 ,
155158 lambda_entropy_weight : float = 0.1 ,
156- lambda_sum_weight : float = 0.1 ,
159+ lambda_sum_weight : float | None = None ,
157160 epsilon : float = 1e-8 ,
158161 ) -> None :
159162 self .a = a
@@ -217,8 +220,11 @@ def call(
217220 + (1 - valid_lambda_r ) * tf .math .log1p (1 - valid_lambda_r )
218221 )
219222
220- lambda_sum = self .lambda_sum_weight * tf .reduce_mean (
221- tf .square (tf .reduce_sum (valid_lambda_r , axis = - 1 ) - 1.0 )
223+ lambda_sum = (
224+ self .lambda_sum_weight
225+ * tf .reduce_mean (tf .square (tf .reduce_sum (valid_lambda_r , axis = - 1 ) - 1.0 ))
226+ if self .lambda_sum_weight is not None
227+ else 0.0
222228 )
223229
224230 total_loss = (
@@ -271,7 +277,7 @@ def __init__( # pylint: disable=too-many-arguments
271277 b : float = 0.7 ,
272278 lambda_reg_weight : float = 0.1 ,
273279 lambda_entropy_weight : float = 0.1 ,
274- lambda_sum_weight : float = 0.1 ,
280+ lambda_sum_weight : float | None = None ,
275281 epsilon : float = 1e-8 ,
276282 ) -> None :
277283 self .a = a
@@ -330,8 +336,11 @@ def call(
330336 + (1 - valid_lambda_r ) * tf .math .log1p (1 - valid_lambda_r )
331337 )
332338
333- lambda_sum = self .lambda_sum_weight * tf .reduce_mean (
334- tf .square (tf .reduce_sum (valid_lambda_r , axis = - 1 ) - 1.0 )
339+ lambda_sum = (
340+ self .lambda_sum_weight
341+ * tf .reduce_mean (tf .square (tf .reduce_sum (valid_lambda_r , axis = - 1 ) - 1.0 ))
342+ if self .lambda_sum_weight is not None
343+ else 0.0
335344 )
336345
337346 total_loss = (
0 commit comments