Skip to content

Commit 87cf988

Browse files
First student run.
1 parent 691161b commit 87cf988

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

examples/mnist_small.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import sys, pathlib
2+
# (Ugly) workaround to enable importing from parent folder without too much hassle.
3+
sys.path.append(str(pathlib.Path(__file__).parent.parent))
4+
5+
import csv
6+
import numpy as np
7+
from nn import NeuralNetwork, Layer, LeakyReLU, CrossEntropyLoss
8+
9+
TRAIN_FILE = pathlib.Path(__file__).parent / "mnistdata/mnist_train.csv"
10+
TEST_FILE = pathlib.Path(__file__).parent / "mnistdata/mnist_test.csv"
11+
12+
def load_data(filepath, delimiter=",", dtype=float):
13+
"""Load a numerical numpy array from a file."""
14+
15+
print(f"Loading {filepath}...")
16+
with open(filepath, "r") as f:
17+
data_iterator = csv.reader(f, delimiter=delimiter)
18+
data_list = list(data_iterator)
19+
data = np.asarray(data_list, dtype=dtype)
20+
print("Done.")
21+
return data
22+
23+
def to_col(x):
24+
return x.reshape((x.size, 1))
25+
26+
def test(net, test_data):
27+
correct = 0
28+
for i, test_row in enumerate(test_data):
29+
if not i%1000:
30+
print(i)
31+
32+
t = test_row[0]
33+
x = to_col(test_row[1:])
34+
out = net.forward_pass(x)
35+
guess = np.argmax(out)
36+
if t == guess:
37+
correct += 1
38+
39+
return correct/test_data.shape[0]
40+
41+
def train(net, train_data):
42+
for i, train_row in enumerate(train_data):
43+
if not i%1000:
44+
print(i)
45+
46+
net.train(to_col(train_row[1:]), train_row[0])
47+
48+
49+
if __name__ == "__main__":
50+
layers = [
51+
Layer(784, 10, LeakyReLU()),
52+
]
53+
net = NeuralNetwork(layers, CrossEntropyLoss(), 0.001)
54+
55+
test_data = load_data(TEST_FILE, delimiter=",", dtype=int)
56+
accuracy = test(net, test_data)
57+
print(f"Accuracy is {100*accuracy:.2f}%") # Expected to be around 10%
58+
59+
train_data = load_data(TRAIN_FILE, delimiter=",", dtype=int)
60+
train(net, train_data)
61+
62+
accuracy = test(net, test_data)
63+
print(f"Accuracy is {100*accuracy:.2f}%")

0 commit comments

Comments
 (0)