Skip to content

Commit 56b130e

Browse files
committed
ClassNLLCriterion accepts distribution as target.
1 parent 40b8e28 commit 56b130e

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

DeepFried2/criteria/ClassNLLCriterion.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,12 @@
33

44
class ClassNLLCriterion:
55
def symb_forward(self, symb_input, symb_targets):
6-
int_targets = _T.cast(symb_targets, 'int32')
7-
return _T.mean(-_T.log(symb_input[_T.arange(symb_targets.shape[0]), int_targets]))
6+
if symb_targets.ndim == 1:
7+
# This is the case when `symb_targets` are 1-hot encoding indices.
8+
int_targets = _T.cast(symb_targets, 'int32')
9+
return _T.mean(-_T.log(symb_input[_T.arange(symb_targets.shape[0]), int_targets]))
10+
elif symb_targets.ndim == symb_input.ndim:
11+
# This is the case when both are full distributions.
12+
return _T.mean(-_T.sum(symb_targets * _T.log(symb_input), axis=symb_input.ndim-1))
13+
else:
14+
assert False, "Mismatch in dimensionalities of `ClassNLLCriterion` input and targets."

0 commit comments

Comments
 (0)