Skip to content

Commit 11fabe9

Browse files
Train students.
1 parent 4a3f6b0 commit 11fabe9

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

examples/teacher_student.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import csv
66
import numpy as np
7-
from nn import NeuralNetwork, Layer, LeakyReLU, CrossEntropyLoss
7+
from nn import NeuralNetwork, Layer, LeakyReLU, CrossEntropyLoss, MSELoss
88

99
TRAIN_FILE = pathlib.Path(__file__).parent / "mnistdata/mnist_train.csv"
1010
TEST_FILE = pathlib.Path(__file__).parent / "mnistdata/mnist_test.csv"
@@ -45,14 +45,32 @@ def train(net, train_data):
4545

4646
net.train(to_col(train_row[1:])/255, train_row[0])
4747

48+
def train_students(teacher, students, train_data):
49+
for i, train_row in enumerate(train_data):
50+
if not i%1000:
51+
print(i)
52+
53+
x = to_col(train_row[1:])/255
54+
out = teacher.forward_pass(x)
55+
for student in students:
56+
student.train(x, out)
57+
4858

4959
if __name__ == "__main__":
5060
layers = [
5161
Layer(784, 16, LeakyReLU()),
5262
Layer(16, 16, LeakyReLU()),
5363
Layer(16, 10, LeakyReLU()),
5464
]
55-
teacher = NeuralNetwork(layers, CrossEntropyLoss(), 0.01)
65+
teacher = NeuralNetwork(layers, CrossEntropyLoss(), 0.03)
66+
students = [
67+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.001),
68+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.003),
69+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.01),
70+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.03),
71+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.1),
72+
NeuralNetwork([Layer(784, 10, LeakyReLU())], MSELoss(), 0.3),
73+
]
5674

5775
test_data = load_data(TEST_FILE, delimiter=",", dtype=int)
5876
accuracy = test(teacher, test_data)
@@ -63,3 +81,10 @@ def train(net, train_data):
6381

6482
accuracy = test(teacher, test_data)
6583
print(f"Accuracy is {100*accuracy:.2f}%")
84+
85+
print("Training students.")
86+
train_students(teacher, students, train_data)
87+
print("Testing students.")
88+
accuracies = [100*test(student, test_data) for student in students]
89+
print(accuracies)
90+
print(f"Teacher accuracy had been {100*accuracy:.2f}%")

0 commit comments

Comments
 (0)