Skip to content

Commit 6ade78b

Browse files
authored
Merge pull request #103 from lucasb-eyer/bce-sumfeatures
BCE Criterion can optionally sum across an axis (features).
2 parents 348abec + f6d9dad commit 6ade78b

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

DeepFried2/criteria/BCECriterion.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,26 @@ class BCECriterion(df.Criterion):
66
Like cross-entropy but also penalizing label-zero predictions directly.
77
"""
88

9-
def __init__(self, clip=None):
9+
def __init__(self, clip=None, sumaxis=None):
1010
"""
1111
- clip: clip inputs to [clip, 1-clip] to avoid potential numerical issues.
12+
- sumaxis: if we want to sum along one or more axes to get a per-sample
13+
BCE in case each sample is made of more than one BCE
14+
(e.g. each pixel in an image.)
1215
"""
1316
df.Criterion.__init__(self)
1417
self.clip = clip
18+
self.sumaxis = sumaxis
1519

1620
def symb_forward(self, symb_input, symb_target):
1721
self._assert_same_dim(symb_input, symb_target)
1822

1923
if self.clip is not None:
2024
symb_input = df.T.clip(symb_input, self.clip, 1-self.clip)
2125

22-
return df.T.nnet.binary_crossentropy(symb_input, symb_target)
26+
bce = df.T.nnet.binary_crossentropy(symb_input, symb_target)
27+
28+
if self.sumaxis is not None:
29+
bce = df.T.sum(bce, self.sumaxis)
30+
31+
return bce

0 commit comments

Comments
 (0)