@@ -229,7 +229,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple:
         return total_loss.mean(), loss_components
 
     def _calculate_implication_loss(
-        self, l: torch.Tensor, r: torch.Tensor
+        self, l_: torch.Tensor, r: torch.Tensor
     ) -> torch.Tensor:
         """
         Calculate implication loss based on T-norm and other parameters.
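Note: the only functional change in this diff is the rename `l` -> `l_`, presumably to satisfy linters that flag single-letter names easily confused with digits (pycodestyle/flake8 E741, "ambiguous variable name 'l'"); behaviour is unchanged.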
@@ -241,17 +241,17 @@ def _calculate_implication_loss(
         Returns:
             torch.Tensor: Calculated implication loss.
         """
-        assert not l.isnan().any(), (
-            f"l contains NaN values - l.shape: {l.shape}, l.isnan().sum(): {l.isnan().sum()}, "
-            f"l: {l}"
+        assert not l_.isnan().any(), (
+            f"l contains NaN values - l.shape: {l_.shape}, l.isnan().sum(): {l_.isnan().sum()}, "
+            f"l: {l_}"
         )
         assert not r.isnan().any(), (
             f"r contains NaN values - r.shape: {r.shape}, r.isnan().sum(): {r.isnan().sum()}, "
             f"r: {r}"
         )
         if self.pos_scalar != 1:
-            l = (
-                torch.pow(l + self.eps, 1 / self.pos_scalar)
+            l_ = (
+                torch.pow(l_ + self.eps, 1 / self.pos_scalar)
                 - math.pow(self.eps, 1 / self.pos_scalar)
             ) / (
                 math.pow(1 + self.eps, 1 / self.pos_scalar)
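For readers following the rescaling in this hunk, here is a minimal standalone sketch of the transformation (the names `pos_scalar` and `eps` mirror the class attributes; the default values below are illustrative assumptions, not the class defaults):

```python
import math

import torch


def rescale_antecedent(l_: torch.Tensor, pos_scalar: int = 2, eps: float = 1e-6) -> torch.Tensor:
    """Illustrative re-implementation of the pos_scalar branch above."""
    root = 1 / pos_scalar
    # Maps [0, 1] onto [0, 1]: 0 -> 0 and 1 -> 1 by construction, while
    # intermediate values move towards 1 for pos_scalar > 1 (roughly x ** root),
    # so violations by confident antecedents are weighted more heavily.
    return (torch.pow(l_ + eps, root) - math.pow(eps, root)) / (
        math.pow(1 + eps, root) - math.pow(eps, root)
    )


print(rescale_antecedent(torch.tensor([0.0, 0.25, 1.0])))  # ~ [0.0, 0.5, 1.0]
```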
@@ -269,21 +269,21 @@ def _calculate_implication_loss(
         # for each implication I, calculate 1 - I(l, 1 - one_min_r)
         # for S-implications, this is equivalent to the t-norm
         if self.fuzzy_implication in ["reichenbach", "rc"]:
-            individual_loss = l * one_min_r
+            individual_loss = l_ * one_min_r
         # xu19 (from Xu et al., 2019: Semantic loss) is not a fuzzy implication, but behaves similarly to the Reichenbach
         # implication
         elif self.fuzzy_implication == "xu19":
-            individual_loss = -torch.log(1 - l * one_min_r)
+            individual_loss = -torch.log(1 - l_ * one_min_r)
         elif self.fuzzy_implication in ["lukasiewicz", "lk"]:
-            individual_loss = torch.relu(l + one_min_r - 1)
+            individual_loss = torch.relu(l_ + one_min_r - 1)
         elif self.fuzzy_implication in ["kleene_dienes", "kd"]:
-            individual_loss = torch.min(l, 1 - r)
+            individual_loss = torch.min(l_, 1 - r)
         elif self.fuzzy_implication in ["goedel", "g"]:
-            individual_loss = torch.where(l <= r, 0, one_min_r)
+            individual_loss = torch.where(l_ <= r, 0, one_min_r)
         elif self.fuzzy_implication in ["reverse-goedel", "rg"]:
-            individual_loss = torch.where(l <= r, 0, l)
+            individual_loss = torch.where(l_ <= r, 0, l_)
         elif self.fuzzy_implication in ["binary", "b"]:
-            individual_loss = torch.where(l <= r, 0, 1).to(dtype=l.dtype)
+            individual_loss = torch.where(l_ <= r, 0, 1).to(dtype=l_.dtype)
         else:
             raise NotImplementedError(
                 f"Unknown fuzzy implication {self.fuzzy_implication}"
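A hedged sanity check for the renamed branches (values are made up; `one_min_r = 1 - r` as in the surrounding code). As the comment above notes, for S-implications the loss reduces to the t-norm of the antecedent and the negated consequent:

```python
import torch

# Toy antecedent/consequent truth values, illustrative only.
l_ = torch.tensor([0.9, 0.2, 0.7])
r = torch.tensor([0.8, 0.9, 0.1])
one_min_r = 1 - r

# Reichenbach: I_rc(l, r) = 1 - l + l * r, so 1 - I_rc = l * (1 - r).
reichenbach = l_ * one_min_r  # tensor([0.1800, 0.0200, 0.6300])

# Lukasiewicz: I_lk(l, r) = min(1, 1 - l + r), so 1 - I_lk = relu(l - r).
lukasiewicz = torch.relu(l_ + one_min_r - 1)

# Goedel: zero loss wherever the consequent is at least as true as the antecedent.
goedel = torch.where(l_ <= r, torch.zeros_like(l_), one_min_r)
```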
@@ -453,8 +453,8 @@ def _build_implication_filter(label_names: List, hierarchy: dict) -> torch.Tensor:
 
 def _build_dense_filter(sparse_filter: torch.Tensor, n_labels: int) -> torch.Tensor:
     res = torch.zeros((n_labels, n_labels), dtype=torch.bool)
-    for l, r in sparse_filter:
-        res[l, r] = True
+    for l_, r in sparse_filter:
+        res[l_, r] = True
     return res
 
 
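Usage sketch for the renamed loop, assuming `_build_dense_filter` as defined above is in scope: each row of `sparse_filter` is an (antecedent, consequent) index pair. The pairs below encode the A -> B, B -> C, D -> C hierarchy from the `__main__` example further down, with the label order [A, B, C, D] assumed for illustration.

```python
import torch

sparse = torch.tensor([[0, 1], [1, 2], [3, 2]])  # A -> B, B -> C, D -> C
dense = _build_dense_filter(sparse, n_labels=4)
# dense[i, j] is True iff label i implies label j:
# tensor([[False,  True, False, False],
#         [False, False,  True, False],
#         [False, False, False, False],
#         [False, False,  True, False]])
```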
@@ -511,8 +511,8 @@ def _build_disjointness_filter(
     random_labels = torch.randint(0, 2, (10, 997))
     for agg in ["sum", "max", "mean", "log-mean"]:
         loss.violations_per_cls_aggregator = agg
-        l = loss(random_preds, random_labels)
-        print(f"Loss with {agg} aggregation for random input:", l)
+        l_ = loss(random_preds, random_labels)
+        print(f"Loss with {agg} aggregation for random input:", l_)
 
     # simplified example for ontology with 4 classes, A -> B, B -> C, D -> C, B and D disjoint
     loss.implication_filter_l = torch.tensor(
@@ -528,5 +528,5 @@ def _build_disjointness_filter(
     labels = [[0, 1, 1, 0], [0, 0, 1, 1]]
     for agg in ["sum", "max", "mean", "log-mean"]:
         loss.violations_per_cls_aggregator = agg
-        l = loss(preds, torch.tensor(labels))
-        print(f"Loss with {agg} aggregation for simple input:", l)
+        l_ = loss(preds, torch.tensor(labels))
+        print(f"Loss with {agg} aggregation for simple input:", l_)
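The two `__main__` loops above only swap the `violations_per_cls_aggregator` string. As a rough mental model (not the class's actual implementation, which is out of view in this diff), the first three modes correspond to standard reductions over the per-class violation dimension; the exact semantics of "log-mean" are defined by the loss class and are not reproduced here.

```python
import torch


def aggregate(violations: torch.Tensor, mode: str) -> torch.Tensor:
    # Hypothetical reduction over the last (per-class) dimension.
    if mode == "sum":
        return violations.sum(dim=-1)
    if mode == "max":
        return violations.max(dim=-1).values
    if mode == "mean":
        return violations.mean(dim=-1)
    raise NotImplementedError(f"Unknown aggregator {mode}")
```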