File tree Expand file tree Collapse file tree 1 file changed +11
-2
lines changed
Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -6,17 +6,26 @@ class BCECriterion(df.Criterion):
66 Like cross-entropy but also penalizing label-zero predictions directly.
77 """
88
9- def __init__ (self , clip = None ):
9+ def __init__ (self , clip = None , sumaxis = None ):
1010 """
1111 - clip: clip inputs to [clip, 1-clip] to avoid potential numerical issues.
12+ - sumaxis: if we want to sum along one or more axes to get a per-sample
13+ BCE in case each sample is made of more than one BCE
14+ (e.g. each pixel in an image.)
1215 """
1316 df .Criterion .__init__ (self )
1417 self .clip = clip
18+ self .sumaxis = sumaxis
1519
1620 def symb_forward (self , symb_input , symb_target ):
1721 self ._assert_same_dim (symb_input , symb_target )
1822
1923 if self .clip is not None :
2024 symb_input = df .T .clip (symb_input , self .clip , 1 - self .clip )
2125
22- return df .T .nnet .binary_crossentropy (symb_input , symb_target )
26+ bce = df .T .nnet .binary_crossentropy (symb_input , symb_target )
27+
28+ if self .sumaxis is not None :
29+ bce = df .T .sum (bce , self .sumaxis )
30+
31+ return bce
You can’t perform that action at this time.
0 commit comments