44
55import csv
66import numpy as np
7- from nn import NeuralNetwork , Layer , LeakyReLU , CrossEntropyLoss
7+ from nn import NeuralNetwork , Layer , LeakyReLU , CrossEntropyLoss , MSELoss
88
99TRAIN_FILE = pathlib .Path (__file__ ).parent / "mnistdata/mnist_train.csv"
1010TEST_FILE = pathlib .Path (__file__ ).parent / "mnistdata/mnist_test.csv"
@@ -45,14 +45,32 @@ def train(net, train_data):
4545
4646 net .train (to_col (train_row [1 :])/ 255 , train_row [0 ])
4747
48+ def train_students (teacher , students , train_data ):
49+ for i , train_row in enumerate (train_data ):
50+ if not i % 1000 :
51+ print (i )
52+
53+ x = to_col (train_row [1 :])/ 255
54+ out = teacher .forward_pass (x )
55+ for student in students :
56+ student .train (x , out )
57+
4858
4959if __name__ == "__main__" :
5060 layers = [
5161 Layer (784 , 16 , LeakyReLU ()),
5262 Layer (16 , 16 , LeakyReLU ()),
5363 Layer (16 , 10 , LeakyReLU ()),
5464 ]
55- teacher = NeuralNetwork (layers , CrossEntropyLoss (), 0.01 )
65+ teacher = NeuralNetwork (layers , CrossEntropyLoss (), 0.03 )
66+ students = [
67+ NeuralNetwork ([Layer (784 , 10 , LeakyReLU ())], MSELoss (), 0.001 ),
68+ NeuralNetwork ([Layer (784 , 10 , LeakyReLU ())], MSELoss (), 0.003 ),
69+ NeuralNetwork ([Layer (784 , 10 , LeakyReLU ())], MSELoss (), 0.01 ),
70+ NeuralNetwork ([Layer (784 , 10 , LeakyReLU ())], MSELoss (), 0.03 ),
71+ NeuralNetwork ([Layer (784 , 10 , LeakyReLU ())], MSELoss (), 0.1 ),
72+ NeuralNetwork ([Layer (784 , 10 , LeakyReLU ())], MSELoss (), 0.3 ),
73+ ]
5674
5775 test_data = load_data (TEST_FILE , delimiter = "," , dtype = int )
5876 accuracy = test (teacher , test_data )
@@ -63,3 +81,10 @@ def train(net, train_data):
6381
6482 accuracy = test (teacher , test_data )
6583 print (f"Accuracy is { 100 * accuracy :.2f} %" )
84+
85+ print ("Training students." )
86+ train_students (teacher , students , train_data )
87+ print ("Testing students." )
88+ accuracies = [100 * test (student , test_data ) for student in students ]
89+ print (accuracies )
90+ print (f"Teacher accuracy had been { 100 * accuracy :.2f} %" )
0 commit comments