-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
33 lines (25 loc) · 1 KB
/
main.py
File metadata and controls
33 lines (25 loc) · 1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from learning import batch_learning
import utility as utl
from mnist import MNIST
from net import MultilayerNet
def run(*, n_hidden_nodes_per_layer=None, act_fun_codes=None, error_fun_code=1,
        n_samples=10000, data_dir='./data/', test_size=0.25, val_size=0.3334):
    """Train and evaluate a MultilayerNet on a random subset of MNIST.

    Loads the MNIST training set, draws a random subset of it, scales it,
    splits it into train / validation / test sets, trains the network with
    ``batch_learning`` (which also receives the validation set), simulates the
    trained network on the test set and prints the result via
    ``utility.print_result``.

    All arguments are keyword-only and default to the values this script
    previously hard-coded, so a plain ``run()`` behaves as before.

    Args:
        n_hidden_nodes_per_layer: Number of nodes in each hidden layer. The
            number of hidden layers is taken from its length.
            Defaults to ``[35, 35]``.
        act_fun_codes: Activation-function codes forwarded to
            ``MultilayerNet``; their meaning is defined there.
            Defaults to ``[0, 0, 1]``.
        error_fun_code: Error-function code forwarded to ``MultilayerNet``.
        n_samples: Size of the random subset drawn from the training set.
        data_dir: Directory passed to ``MNIST`` to locate the dataset files.
        test_size: ``test_size`` for the first ``train_test_split`` call,
            which carves out the held-out test set.
        val_size: ``test_size`` for the second ``train_test_split`` call,
            applied to what remains after the test split.

    Returns:
        The trained network returned by ``batch_learning``.
    """
    # None sentinels + fresh lists avoid mutable default arguments shared
    # across calls, and copying protects a caller-supplied list.
    n_hidden_nodes_per_layer = (
        [35, 35] if n_hidden_nodes_per_layer is None
        else list(n_hidden_nodes_per_layer)
    )
    act_fun_codes = [0, 0, 1] if act_fun_codes is None else list(act_fun_codes)
    # Derived from the list rather than hard-coded separately, so the layer
    # count and the per-layer sizes can never disagree.
    n_hidden_layers = len(n_hidden_nodes_per_layer)

    # Load the dataset.
    mndata = MNIST(data_dir)
    X, t = mndata.load_training()
    X = utl.get_mnist_data(X)
    t = utl.get_mnist_labels(t)
    X, t = utl.get_random_dataset(X, t, n_samples=n_samples)
    # NOTE(review): scaling happens before the split. If get_scaled_data
    # derives statistics from X (instead of using a fixed constant), test
    # data leaks into training; confirm against utility.
    X = utl.get_scaled_data(X)

    # Assuming test_size is a fraction of the input, the defaults give roughly
    # 50% train / 25% validation / 25% test (0.75 * 0.3334 ~= 0.25).
    X_train, X_test, t_train, t_test = utl.train_test_split(
        X, t, test_size=test_size)
    X_train, X_val, t_train, t_val = utl.train_test_split(
        X_train, t_train, test_size=val_size)

    net = MultilayerNet(n_hidden_layers=n_hidden_layers,
                        n_hidden_nodes_per_layer=n_hidden_nodes_per_layer,
                        act_fun_codes=act_fun_codes,
                        error_fun_code=error_fun_code)
    net = batch_learning(net, X_train, t_train, X_val, t_val)

    y_test = net.sim(X_test)
    utl.print_result(y_test, t_test)
    return net
# Run the experiment only when executed as a script, not when this module is
# imported (e.g. by a test or another module).
if __name__ == "__main__":
    run()