File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change 33
44class ClassNLLCriterion :
55 def symb_forward (self , symb_input , symb_targets ):
6- int_targets = _T .cast (symb_targets , 'int32' )
7- return _T .mean (- _T .log (symb_input [_T .arange (symb_targets .shape [0 ]), int_targets ]))
6+ if symb_targets .ndim == 1 :
7+ # This is the case when `symb_targets` are 1-hot encoding indices.
8+ int_targets = _T .cast (symb_targets , 'int32' )
9+ return _T .mean (- _T .log (symb_input [_T .arange (symb_targets .shape [0 ]), int_targets ]))
10+ elif symb_targets .ndim == symb_input .ndim :
11+ # This is the case when both are full distributions.
12+ return _T .mean (- _T .sum (symb_targets * _T .log (symb_input ), axis = symb_input .ndim - 1 ))
13+ else :
14+ assert False , "Mismatch in dimensionalities of `ClassNLLCriterion` input and targets."
You can’t perform that action at this time.
0 commit comments