6 changes: 6 additions & 0 deletions .gitignore
@@ -1 +1,7 @@
.ipynb_checkpoints/
__pycache__/
*.py[cod]
*$py.class
venv/
.env
.DS_Store
24 changes: 23 additions & 1 deletion micrograd/engine.py
@@ -1,4 +1,4 @@

import math
class Value:
    """ stores a single scalar value and its gradient """

@@ -42,6 +42,28 @@ def _backward():

        return out

    def tanh(self):
        x = self.data
        t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)
        out = Value(t, (self,), 'tanh')

        def _backward():
            self.grad += (1 - t**2) * out.grad
        out._backward = _backward

        return out

    def sigmoid(self):
        x = self.data
        s = 1 / (1 + math.exp(-x))
        out = Value(s, (self,), 'sigmoid')

        def _backward():
            self.grad += (s * (1 - s)) * out.grad
        out._backward = _backward

        return out

    def relu(self):
        out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU')

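The new tanh and sigmoid backward passes can be sanity-checked against a finite-difference estimate. A minimal sketch, not part of this diff, assuming Value.backward() works as in upstream micrograd:

import math
from micrograd.engine import Value

def numerical_grad(f, x0, h=1e-6):
    # central-difference estimate of df/dx at x0
    return (f(x0 + h) - f(x0 - h)) / (2 * h)

for name, fn in [('tanh', math.tanh), ('sigmoid', lambda z: 1 / (1 + math.exp(-z)))]:
    v = Value(0.7)
    out = getattr(v, name)()   # v.tanh() or v.sigmoid()
    out.backward()
    assert abs(v.grad - numerical_grad(fn, 0.7)) < 1e-6, name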
50 changes: 50 additions & 0 deletions micrograd/optim.py
@@ -0,0 +1,50 @@
from micrograd.engine import Value

class Optimizer:
    def __init__(self, parameters):
        self.parameters = list(parameters)

    def step(self):
        raise NotImplementedError

    def zero_grad(self):
        for p in self.parameters:
            p.grad = 0

class SGD(Optimizer):
    def __init__(self, parameters, lr=0.01, momentum=0.0):
        super().__init__(parameters)
        self.lr = lr
        self.momentum = momentum
        self.velocities = {p: 0 for p in self.parameters}

    def step(self):
        for p in self.parameters:
            v = self.velocities[p]
            v = self.momentum * v - self.lr * p.grad
            self.velocities[p] = v
            p.data += v

class Adam(Optimizer):
    def __init__(self, parameters, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(parameters)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.m = {p: 0 for p in self.parameters}
        self.v = {p: 0 for p in self.parameters}
        self.t = 0

    def step(self):
        self.t += 1
        for p in self.parameters:
            if p.grad == 0:
                # skip parameters that received no gradient this step
                continue

            self.m[p] = self.beta1 * self.m[p] + (1 - self.beta1) * p.grad
            self.v[p] = self.beta2 * self.v[p] + (1 - self.beta2) * (p.grad ** 2)

            # bias-corrected first and second moment estimates
            m_hat = self.m[p] / (1 - self.beta1 ** self.t)
            v_hat = self.v[p] / (1 - self.beta2 ** self.t)

            p.data -= self.lr * m_hat / (v_hat**0.5 + self.eps)
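A minimal usage sketch for the new optimizers, not part of this diff, assuming upstream micrograd.nn.MLP and its parameters() method:

from micrograd.nn import MLP
from micrograd.optim import Adam

# tiny regression problem: fit y = 2x on four points
model = MLP(1, [4, 1])
optimizer = Adam(model.parameters(), lr=0.01)

xs = [[-1.0], [0.0], [1.0], [2.0]]
ys = [-2.0, 0.0, 2.0, 4.0]

for _ in range(200):
    optimizer.zero_grad()
    preds = [model(x) for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys))
    loss.backward()
    optimizer.step()

SGD(model.parameters(), lr=0.05, momentum=0.9) would slot into the same loop unchanged.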
41 changes: 41 additions & 0 deletions test/test_optim.py
@@ -0,0 +1,41 @@
import unittest
from micrograd.engine import Value
from micrograd.optim import SGD, Adam

class TestOptim(unittest.TestCase):

    def test_sgd(self):
        # Simple quadratic optimization: y = (x - 3)^2
        # Minimum at x = 3
        x = Value(0.0)
        optimizer = SGD([x], lr=0.1)

        for _ in range(100):
            optimizer.zero_grad()
            loss = (x - 3)**2
            loss.backward()
            optimizer.step()

        self.assertAlmostEqual(x.data, 3.0, delta=0.1)

    def test_adam(self):
        # Simple quadratic optimization: y = (x - 3)^2
        # Minimum at x = 3
        x = Value(0.0)
        optimizer = Adam([x], lr=0.1)

        for _ in range(300):
            optimizer.zero_grad()
            loss = (x - 3)**2
            loss.backward()
            optimizer.step()

        self.assertAlmostEqual(x.data, 3.0, delta=0.01)

if __name__ == '__main__':
    unittest.main()
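For stronger coverage, the same quadratic could also be checked step-for-step against PyTorch's reference Adam, in the spirit of micrograd's existing torch-based sanity checks. A rough sketch, not part of this PR, assuming torch is installed; the tolerance is loose to absorb floating-point drift:

import torch
from micrograd.engine import Value
from micrograd.optim import Adam

x_mg = Value(0.0)
opt_mg = Adam([x_mg], lr=0.1)

x_pt = torch.tensor(0.0, requires_grad=True)
opt_pt = torch.optim.Adam([x_pt], lr=0.1)

for _ in range(50):
    opt_mg.zero_grad()
    ((x_mg - 3) ** 2).backward()
    opt_mg.step()

    opt_pt.zero_grad()
    ((x_pt - 3) ** 2).backward()
    opt_pt.step()

# both optimizers should land at essentially the same point
assert abs(x_mg.data - x_pt.item()) < 1e-3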