Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 9 additions & 30 deletions micrograd/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,53 +2,31 @@
class Value:
""" stores a single scalar value and its gradient """

def __init__(self, data, _children=(), _op=''):
def __init__(self, data, _children=(), _op='', _local_grads=()):
self.data = data
self.grad = 0
# internal variables used for autograd graph construction
self._backward = lambda: None
self._prev = set(_children)
self._prev = _children
self._op = _op # the op that produced this node, for graphviz / debugging / etc
self._local_grads = _local_grads # local derivative of this node w.r.t. its children

def __add__(self, other):
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data + other.data, (self, other), '+')

def _backward():
self.grad += out.grad
other.grad += out.grad
out._backward = _backward

out = Value(self.data + other.data, (self, other), '+', (1, 1))
return out

def __mul__(self, other):
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data * other.data, (self, other), '*')

def _backward():
self.grad += other.data * out.grad
other.grad += self.data * out.grad
out._backward = _backward

out = Value(self.data * other.data, (self, other), '*', (other.data, self.data))
return out

def __pow__(self, other):
assert isinstance(other, (int, float)), "only supporting int/float powers for now"
out = Value(self.data**other, (self,), f'**{other}')

def _backward():
self.grad += (other * self.data**(other-1)) * out.grad
out._backward = _backward

out = Value(self.data**other, (self,), f'**{other}', (other * self.data**(other-1),))
return out

def relu(self):
out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU')

def _backward():
self.grad += (out.data > 0) * out.grad
out._backward = _backward

out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU', (float(self.data > 0),))
return out

def backward(self):
Expand All @@ -67,7 +45,8 @@ def build_topo(v):
# go one variable at a time and apply the chain rule to get its gradient
self.grad = 1
for v in reversed(topo):
v._backward()
for child, local_grad in zip(v._prev, v._local_grads):
child.grad += local_grad * v.grad

def __neg__(self): # -self
return self * -1
Expand Down