diff --git a/micrograd/engine.py b/micrograd/engine.py
index afd82cc5..ec531b2d 100644
--- a/micrograd/engine.py
+++ b/micrograd/engine.py
@@ -2,53 +2,31 @@ class Value:
     """ stores a single scalar value and its gradient """
 
-    def __init__(self, data, _children=(), _op=''):
+    def __init__(self, data, _children=(), _op='', _local_grads=()):
         self.data = data
         self.grad = 0
         # internal variables used for autograd graph construction
-        self._backward = lambda: None
-        self._prev = set(_children)
+        self._prev = _children
         self._op = _op # the op that produced this node, for graphviz / debugging / etc
+        self._local_grads = _local_grads # local derivative of this node w.r.t. its children
 
     def __add__(self, other):
         other = other if isinstance(other, Value) else Value(other)
-        out = Value(self.data + other.data, (self, other), '+')
-
-        def _backward():
-            self.grad += out.grad
-            other.grad += out.grad
-        out._backward = _backward
-
+        out = Value(self.data + other.data, (self, other), '+', (1, 1))
         return out
 
     def __mul__(self, other):
         other = other if isinstance(other, Value) else Value(other)
-        out = Value(self.data * other.data, (self, other), '*')
-
-        def _backward():
-            self.grad += other.data * out.grad
-            other.grad += self.data * out.grad
-        out._backward = _backward
-
+        out = Value(self.data * other.data, (self, other), '*', (other.data, self.data))
         return out
 
     def __pow__(self, other):
         assert isinstance(other, (int, float)), "only supporting int/float powers for now"
-        out = Value(self.data**other, (self,), f'**{other}')
-
-        def _backward():
-            self.grad += (other * self.data**(other-1)) * out.grad
-        out._backward = _backward
-
+        out = Value(self.data**other, (self,), f'**{other}', (other * self.data**(other-1),))
         return out
 
     def relu(self):
-        out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU')
-
-        def _backward():
-            self.grad += (out.data > 0) * out.grad
-        out._backward = _backward
-
+        out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU', (float(self.data > 0),))
         return out
 
     def backward(self):
 
@@ -67,7 +45,8 @@ def build_topo(v):
         # go one variable at a time and apply the chain rule to get its gradient
         self.grad = 1
         for v in reversed(topo):
-            v._backward()
+            for child, local_grad in zip(v._prev, v._local_grads):
+                child.grad += local_grad * v.grad
 
     def __neg__(self): # -self
         return self * -1