Skip to content

Commit 01a0104

Browse files
author
Jencir Lee
committed
add log operator
1 parent ee39be2 commit 01a0104

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

micrograd/engine.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
from numpy import (array, ndarray, nan, ones, zeros, full,
33
shape as get_shape, where, sum as npsum,
4-
log1p, arctanh, broadcast_arrays, expand_dims,
4+
log, log1p, arctanh, broadcast_arrays, expand_dims,
55
prod, tensordot, isnan, all as npall)
66
from numbers import Number
77
from warnings import warn
@@ -133,6 +133,20 @@ def _backward():
133133

134134
return out
135135

136+
def log(self):
137+
out = Value(log(self.data), (self,), 'log')
138+
139+
def _forward(**kwds):
140+
out.data = log(self.data)
141+
out._forward = _forward
142+
143+
def _backward(**kwds):
144+
valid_data = where(self.data >= 0, self.data, nan)
145+
self.grad += 1 / valid_data * out.grad
146+
out._backward = _backward
147+
148+
return out
149+
136150
def log1p(self):
137151
out = Value(log1p(self.data), (self,), 'log1p')
138152

@@ -141,7 +155,7 @@ def _forward(**kwds):
141155
out._forward = _forward
142156

143157
def _backward():
144-
valid_data = where(self.data >= 0, self.data, nan)
158+
valid_data = where(self.data >= -1, self.data, nan)
145159
self.grad += 1 / (1 + valid_data) * out.grad
146160
out._backward = _backward
147161

tests/test_engine.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ def test_basic(self):
2323
self.assertTrue(allclose(b.grad, [0, 18]))
2424
self.assertTrue(allclose(c.grad, [1, 1]))
2525

26+
# test log
27+
d = a.log()
28+
d.forward(a=array([2, 3]))
29+
d.backward()
30+
self.assertTrue(allclose(d.data, [0.69314718, 1.09861229]))
31+
self.assertTrue(allclose(a.grad, [.5, 1 / 3]))
32+
2633
# test log1p
2734
d = a.log1p()
2835
d.forward(a=array([2, 3]))

0 commit comments

Comments
 (0)