11
22from numpy import (array , ndarray , nan , ones , zeros , full ,
33 shape as get_shape , where , sum as npsum ,
4- log1p , arctanh , broadcast_arrays , expand_dims ,
4+ log , log1p , arctanh , broadcast_arrays , expand_dims ,
55 prod , tensordot , isnan , all as npall )
66from numbers import Number
77from warnings import warn
@@ -133,6 +133,20 @@ def _backward():
133133
134134 return out
135135
136+ def log (self ):
137+ out = Value (log (self .data ), (self ,), 'log' )
138+
139+ def _forward (** kwds ):
140+ out .data = log (self .data )
141+ out ._forward = _forward
142+
143+ def _backward (** kwds ):
144+ valid_data = where (self .data >= 0 , self .data , nan )
145+ self .grad += 1 / valid_data * out .grad
146+ out ._backward = _backward
147+
148+ return out
149+
136150 def log1p (self ):
137151 out = Value (log1p (self .data ), (self ,), 'log1p' )
138152
@@ -141,7 +155,7 @@ def _forward(**kwds):
141155 out ._forward = _forward
142156
143157 def _backward ():
144- valid_data = where (self .data >= 0 , self .data , nan )
158+ valid_data = where (self .data >= - 1 , self .data , nan )
145159 self .grad += 1 / (1 + valid_data ) * out .grad
146160 out ._backward = _backward
147161
0 commit comments