Skip to content

Commit ea1e3f9

Browse files
Merge pull request #1 from mathspp/teacher-student
Teacher-student experiment
2 parents 691161b + 11fabe9 commit ea1e3f9

File tree

2 files changed

+153
-0
lines changed

2 files changed

+153
-0
lines changed

examples/mnist_small.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import sys, pathlib
2+
# (Ugly) workaround to enable importing from parent folder without too much hassle.
3+
sys.path.append(str(pathlib.Path(__file__).parent.parent))
4+
5+
import csv
6+
import numpy as np
7+
from nn import NeuralNetwork, Layer, LeakyReLU, CrossEntropyLoss
8+
9+
TRAIN_FILE = pathlib.Path(__file__).parent / "mnistdata/mnist_train.csv"
10+
TEST_FILE = pathlib.Path(__file__).parent / "mnistdata/mnist_test.csv"
11+
12+
def load_data(filepath, delimiter=",", dtype=float):
13+
"""Load a numerical numpy array from a file."""
14+
15+
print(f"Loading {filepath}...")
16+
with open(filepath, "r") as f:
17+
data_iterator = csv.reader(f, delimiter=delimiter)
18+
data_list = list(data_iterator)
19+
data = np.asarray(data_list, dtype=dtype)
20+
print("Done.")
21+
return data
22+
23+
def to_col(x):
24+
return x.reshape((x.size, 1))
25+
26+
def test(net, test_data):
27+
correct = 0
28+
for i, test_row in enumerate(test_data):
29+
if not i%1000:
30+
print(i)
31+
32+
t = test_row[0]
33+
x = to_col(test_row[1:])/255
34+
out = net.forward_pass(x)
35+
guess = np.argmax(out)
36+
if t == guess:
37+
correct += 1
38+
39+
return correct/test_data.shape[0]
40+
41+
def train(net, train_data):
42+
for i, train_row in enumerate(train_data):
43+
if not i%1000:
44+
print(i)
45+
46+
net.train(to_col(train_row[1:])/255, train_row[0])
47+
48+
49+
if __name__ == "__main__":
50+
layers = [
51+
Layer(784, 10, LeakyReLU()),
52+
]
53+
net = NeuralNetwork(layers, CrossEntropyLoss(), 0.001)
54+
55+
test_data = load_data(TEST_FILE, delimiter=",", dtype=int)
56+
accuracy = test(net, test_data)
57+
print(f"Accuracy is {100*accuracy:.2f}%") # Expected to be around 10%
58+
59+
train_data = load_data(TRAIN_FILE, delimiter=",", dtype=int)
60+
train(net, train_data)
61+
62+
accuracy = test(net, test_data)
63+
print(f"Accuracy is {100*accuracy:.2f}%")

examples/teacher_student.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import sys, pathlib
2+
# (Ugly) workaround to enable importing from parent folder without too much hassle.
3+
sys.path.append(str(pathlib.Path(__file__).parent.parent))
4+
5+
import csv
6+
import numpy as np
7+
from nn import NeuralNetwork, Layer, LeakyReLU, CrossEntropyLoss, MSELoss
8+
9+
TRAIN_FILE = pathlib.Path(__file__).parent / "mnistdata/mnist_train.csv"
10+
TEST_FILE = pathlib.Path(__file__).parent / "mnistdata/mnist_test.csv"
11+
12+
def load_data(filepath, delimiter=",", dtype=float):
13+
"""Load a numerical numpy array from a file."""
14+
15+
print(f"Loading {filepath}...")
16+
with open(filepath, "r") as f:
17+
data_iterator = csv.reader(f, delimiter=delimiter)
18+
data_list = list(data_iterator)
19+
data = np.asarray(data_list, dtype=dtype)
20+
print("Done.")
21+
return data
22+
23+
def to_col(x):
24+
return x.reshape((x.size, 1))
25+
26+
def test(net, test_data):
27+
correct = 0
28+
for i, test_row in enumerate(test_data):
29+
if not i%1000:
30+
print(i)
31+
32+
t = test_row[0]
33+
x = to_col(test_row[1:])/255
34+
out = net.forward_pass(x)
35+
guess = np.argmax(out)
36+
if t == guess:
37+
correct += 1
38+
39+
return correct/test_data.shape[0]
40+
41+
def train(net, train_data):
42+
for i, train_row in enumerate(train_data):
43+
if not i%1000:
44+
print(i)
45+
46+
net.train(to_col(train_row[1:])/255, train_row[0])
47+
48+
def train_students(teacher, students, train_data):
49+
for i, train_row in enumerate(train_data):
50+
if not i%1000:
51+
print(i)
52+
53+
x = to_col(train_row[1:])/255
54+
out = teacher.forward_pass(x)
55+
for student in students:
56+
student.train(x, out)
57+
58+
59+
if __name__ == "__main__":
60+
layers = [
61+
Layer(784, 16, LeakyReLU()),
62+
Layer(16, 16, LeakyReLU()),
63+
Layer(16, 10, LeakyReLU()),
64+
]
65+
teacher = NeuralNetwork(layers, CrossEntropyLoss(), 0.03)
66+
students = [
67+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.001),
68+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.003),
69+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.01),
70+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.03),
71+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.1),
72+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.3),
73+
]
74+
75+
test_data = load_data(TEST_FILE, delimiter=",", dtype=int)
76+
accuracy = test(teacher, test_data)
77+
print(f"Accuracy is {100*accuracy:.2f}%") # Expected to be around 10%
78+
79+
train_data = load_data(TRAIN_FILE, delimiter=",", dtype=int)
80+
train(teacher, train_data)
81+
82+
accuracy = test(teacher, test_data)
83+
print(f"Accuracy is {100*accuracy:.2f}%")
84+
85+
print("Training students.")
86+
train_students(teacher, students, train_data)
87+
print("Testing students.")
88+
accuracies = [100*test(student, test_data) for student in students]
89+
print(accuracies)
90+
print(f"Teacher accuracy had been {100*accuracy:.2f}%")

0 commit comments

Comments
 (0)