Skip to content

Commit 3687daa

Browse files
author
Jencir Lee
committed
fix reduce ops with negative axis
1 parent a35faa5 commit 3687daa

File tree

3 files changed

+96
-9
lines changed

3 files changed

+96
-9
lines changed

micrograd/engine.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -202,25 +202,23 @@ def sum(self, axis=None):
202202

203203
out = Value(self.data.sum(axis=axis), (self,), 'sum')
204204

205+
de_neg = lambda x: self.ndim + x if x < 0 else x
205206
if axis is None:
206-
new_shape = self.shape
207-
elif isinstance(axis, int):
208-
new_shape = [h if j == axis else 1
209-
for j, h in enumerate(self.shape)]
207+
expand_axis = tuple(range(self.data.ndim))
208+
elif _shape(axis) == ():
209+
expand_axis = de_neg(axis)
210210
else:
211-
new_shape = [h if j in axis else 1
212-
for j, h in enumerate(self.shape)]
213-
m_new = ones(new_shape)
211+
expand_axis = tuple(map(de_neg, axis))
214212

215-
expand_axis = tuple(range(self.data.ndim)) if axis is None else axis
213+
arr_orig_shape = ones(self.shape)
216214

217215
def _forward(**kwds):
218216
out.data = self.data.sum(axis=axis)
219217
out._forward = _forward
220218

221219
def _backward():
222220
self.grad += broadcast_arrays(expand_dims(out.grad, expand_axis),
223-
m_new)[0]
221+
arr_orig_shape)[0]
224222
out._backward = _backward
225223

226224
return out

tests/test_engine.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,36 @@ def test_sum_op(self):
9494
c.backward()
9595
self.assertTrue(allclose(a.grad, 0))
9696

97+
def test_sum_op_neg_axis(self):
98+
99+
a = Value(shape=(2, 2, 3), name='a')
100+
b = (a.sum(axis=(0, -1)) - 31).relu()
101+
b.forward(a=array([[[1, 2, 3], [4, 5, 6]],
102+
[[7, 8, 9], [10, 11, 12]]]))
103+
b.backward()
104+
self.assertTrue(allclose(b.data, [0, 17]))
105+
self.assertTrue(allclose(a.grad, [[[0, 0, 0], [1, 1, 1]],
106+
[[0, 0, 0], [1, 1, 1]]]))
107+
108+
b = (a.sum(axis=-3) - 10).relu()
109+
b.forward(a=array([[[1, 2, 3], [4, 5, 6]],
110+
[[7, 8, 9], [10, 11, 12]]]))
111+
b.backward()
112+
self.assertTrue(allclose(a.grad, [[[0, 0, 1], [1, 1, 1]],
113+
[[0, 0, 1], [1, 1, 1]]]))
114+
115+
b = (a.sum() - 77).relu()
116+
c = (a.sum() - 79).relu()
117+
b.forward(a=array([[[1, 2, 3], [4, 5, 6]],
118+
[[7, 8, 9], [10, 11, 12]]]))
119+
b.backward()
120+
self.assertTrue(allclose(a.grad, 1))
121+
122+
c.forward(a=array([[[1, 2, 3], [4, 5, 6]],
123+
[[7, 8, 9], [10, 11, 12]]]))
124+
c.backward()
125+
self.assertTrue(allclose(a.grad, 0))
126+
97127
def test_mean_op(self):
98128

99129
a = Value(shape=(2, 2, 3), name='a')
@@ -129,6 +159,41 @@ def test_mean_op(self):
129159
c.backward()
130160
self.assertTrue(allclose(a.grad, 0))
131161

162+
def test_mean_op_neg_axis(self):
163+
164+
a = Value(shape=(2, 2, 3), name='a')
165+
self.assertTrue(a.mean()._op == '*')
166+
167+
b = (a.mean(axis=(0, -1)) - 6).relu()
168+
b.forward(a=array([[[1, 2, 3], [4, 5, 6]],
169+
[[7, 8, 9], [10, 11, 12]]]))
170+
b.backward()
171+
y = 1 / 6
172+
self.assertTrue(allclose(b.data, [0, 2]))
173+
self.assertTrue(allclose(a.grad, [[[0, 0, 0], [y, y, y]],
174+
[[0, 0, 0], [y, y, y]]]))
175+
176+
b = (a.mean(axis=-3) - 5).relu()
177+
b.forward(a=array([[[1, 2, 3], [4, 5, 6]],
178+
[[7, 8, 9], [10, 11, 12]]]))
179+
b.backward()
180+
y = .5
181+
self.assertTrue(allclose(a.grad, [[[0, 0, y], [y, y, y]],
182+
[[0, 0, y], [y, y, y]]]))
183+
184+
b = (a.mean() - 6).relu()
185+
c = (a.mean() - 7).relu()
186+
b.forward(a=array([[[1, 2, 3], [4, 5, 6]],
187+
[[7, 8, 9], [10, 11, 12]]]))
188+
b.backward()
189+
y = 1 / 12
190+
self.assertTrue(allclose(a.grad, y))
191+
192+
c.forward(a=array([[[1, 2, 3], [4, 5, 6]],
193+
[[7, 8, 9], [10, 11, 12]]]))
194+
c.backward()
195+
self.assertTrue(allclose(a.grad, 0))
196+
132197
def test_tensordot_op(self):
133198

134199
a = Value(empty((2, 3, 4)))

tests/test_vs_torch.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,27 @@ def test_reduce_ops(self):
166166
c2.backward()
167167
self.assertTrue(allclose(c.data, c2.data))
168168
self.assertTrue(allclose(a.grad, a2.grad))
169+
170+
def test_reduce_ops_neg_axis(self):
171+
172+
a = Value(array([[[1, 2, -2], [2, 1, 0]],
173+
[[-2, 1, 0], [3, 2, 1]]]))
174+
b = a.mean(axis=(0, -1)).relu().sum()
175+
c = a.mean(axis=(-3, -2)).relu().mean()
176+
177+
a2 = Tensor([[[1, 2, -2], [2, 1, 0]],
178+
[[-2, 1, 0], [3, 2, 1]]])
179+
a2.requires_grad = True
180+
b2 = a2.mean(axis=(0, -1)).relu().sum()
181+
c2 = a2.mean(axis=(-3, -2)).relu().mean()
182+
183+
b.backward()
184+
b2.backward()
185+
self.assertTrue(allclose(b.data, b2.data))
186+
self.assertTrue(allclose(a.grad, a2.grad))
187+
188+
a2.grad = None
189+
c.backward()
190+
c2.backward()
191+
self.assertTrue(allclose(c.data, c2.data))
192+
self.assertTrue(allclose(a.grad, a2.grad))

0 commit comments

Comments
 (0)