Skip to content

Commit ee61b9d

Browse files
author
Jencir Lee
committed
add sum operator
1 parent 662f403 commit ee61b9d

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

micrograd/engine.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
from numpy import (ndarray, nan, ones, zeros, full,
33
shape as get_shape, where, sum as npsum, mean,
4-
log1p, arctanh)
4+
log1p, arctanh, broadcast_arrays, expand_dims)
55

66
class Value:
77
""" stores a single scalar value and its gradient """
@@ -145,6 +145,24 @@ def _backward():
145145

146146
return out
147147

148+
def sum(self, axis=None):
149+
out = Value(npsum(self.data, axis=axis), (self,), 'sum')
150+
151+
new_shape = [h if j in axis else 1
152+
for j, h in enumerate(self.shape)]
153+
m_new = ones(new_shape)
154+
155+
def _forward(**kwds):
156+
out.data = npsum(self.data, axis=axis)
157+
out._forward = _forward
158+
159+
def _backward():
160+
self.grad += broadcast_arrays(expand_dims(out.grad, axis),
161+
m_new)[0]
162+
out._backward = _backward
163+
164+
return out
165+
148166
def build_topology(self):
149167
# topological order all of the children in the graph
150168
if not hasattr(self, 'topo'):

tests/test_engine.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,15 @@
44

55
class AutodiffTest(TestCase):
66

7-
def test_basic(self):
7+
def test_sanity_check(self):
88

9+
a = Value(shape=(2,), name='a')
10+
a.forward(a=array([2, 3]))
11+
a.backward()
12+
self.assertTrue(allclose(a.grad, [1, 1]))
13+
14+
def test_basic(self):
15+
# test arithmetic, relu
916
a = Value(shape=(2,), name='a')
1017
b = Value(shape=(2,), name='b')
1118
c = (a + 2).relu() * b ** 2
@@ -16,6 +23,7 @@ def test_basic(self):
1623
self.assertTrue(allclose(b.grad, [0, 18]))
1724
self.assertTrue(allclose(c.grad, [1, 1]))
1825

26+
# test log1p
1927
d = a.log1p()
2028
d.forward(a=array([2, 3]))
2129
d.backward()
@@ -27,25 +35,31 @@ def test_basic(self):
2735
self.assertTrue(allclose(d.data, [nan, 1.38629436], equal_nan=True))
2836
self.assertTrue(allclose(a.grad, [nan, 0.25], equal_nan=True))
2937

38+
# test transpose
3039
f = Value(shape=(2, 1), name='f')
3140
g = f.T ** 2
3241
g.forward(f=array([[2], [-1]]))
3342
g.backward()
3443
self.assertTrue(allclose(f.grad, [[4], [-2]]))
3544

45+
# test arctanh
3646
h = Value(shape=(5,), name='h')
3747
k = (h * 2).arctanh()
3848
k.forward(h=array([-1, -.5, 0, .5, 1]))
3949
k.backward()
4050
self.assertTrue(allclose(h.grad, [nan, inf, 2, inf, nan],
4151
equal_nan=True))
4252

43-
def test_sanity_check(self):
53+
def test_reduce_ops(self):
4454

45-
a = Value(shape=(2,), name='a')
46-
a.forward(a=array([2, 3]))
47-
a.backward()
48-
self.assertTrue(allclose(a.grad, [1, 1]))
55+
a = Value(shape=(2, 2, 3), name='a')
56+
b = (a.sum(axis=(0, 2)) - 31).relu()
57+
b.forward(a=array([[[1, 2, 3], [4, 5, 6]],
58+
[[7, 8, 9], [10, 11, 12]]]))
59+
b.backward()
60+
self.assertTrue(allclose(b.data, [0, 17]))
61+
self.assertTrue(allclose(a.grad, [[[0, 0, 0], [1, 1, 1]],
62+
[[0, 0, 0], [1, 1, 1]]]))
4963

5064
def test_ops(self):
5165
x = Value(-4.0)

0 commit comments

Comments
 (0)