Skip to content

Commit 2d39094

Browse files
committed
Implement MNIST model with custom MLP
1 parent 5185521 commit 2d39094

File tree

1 file changed

+88
-8
lines changed

1 file changed

+88
-8
lines changed

hw1/mnist.py

Lines changed: 88 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,48 @@
11
"""Problem 3 - Training on MNIST"""
22
import numpy as np
3+
import pdb
4+
5+
from mytorch.optim.sgd import SGD
6+
from mytorch.nn.activations import ReLU
7+
from mytorch.nn.loss import CrossEntropyLoss
8+
from mytorch.nn.linear import Linear
9+
from mytorch.nn.sequential import Sequential
10+
from mytorch.tensor import Tensor
311

412
# TODO: Import any mytorch packages you need (XELoss, SGD, etc)
513

614
# NOTE: Batch size pre-set to 100. Shouldn't need to change.
715
BATCH_SIZE = 100
816

17+
918
def mnist(train_x, train_y, val_x, val_y):
    """Problem 3.1: Initialize objects and start training
    You won't need to call this function yourself.
    (Data is provided by autograder)

    Args:
        train_x (np.array): training data (55000, 784)
        train_y (np.array): training labels (55000,)
        val_x (np.array): validation data (5000, 784)
        val_y (np.array): validation labels (5000,)
    Returns:
        val_accuracies (list(float)): List of accuracies per validation round
                                      (num_epochs,)
    """
    # NOTE(review): the target architecture in the writeup is
    # Linear(784, 20) -> BatchNorm1d(20) -> ReLU() -> Linear(20, 10), but
    # BatchNorm1d is neither imported nor used here — confirm whether it is
    # required before grading.
    model = Sequential(Linear(784, 20), ReLU(), Linear(20, 10))

    # Writeup-specified learning rate; momentum kept at 0.1 as in the original.
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.1)

    # Multi-class classification over 10 digits -> cross-entropy loss.
    criterion = CrossEntropyLoss()

    # Run the training loop (defined below) and collect one accuracy per
    # validation round.
    val_accuracies = train(
        model, optimizer, criterion, train_x, train_y, val_x, val_y, num_epochs=3
    )
    return val_accuracies
2847

2948

def train(model, optimizer, criterion, train_x, train_y, val_x, val_y, num_epochs=3):
    """Problem 3.2: Training routine that runs for `num_epochs` epochs.
    Returns:
        val_accuracies (list): (num_epochs,)
    """
    # Accumulated across ALL epochs.
    # Bug fix: the original re-created this list inside the epoch loop, so
    # every epoch discarded the previous epoch's validation results.
    val_accuracies = []

    print("Training...")
    model.train()
    for epoch in range(num_epochs):
        # Reshuffle the training data every epoch so batches differ.
        # Bug fix: np.random.shuffle() shuffles in place and returns None, so
        # the original indexed with None and never actually shuffled; use
        # np.random.permutation to get the shuffled index array.
        permutation = np.random.permutation(len(train_x))
        train_x = train_x[permutation]
        train_y = train_y[permutation]

        # Split into minibatches of BATCH_SIZE samples each.
        # Bug fix: np.array_split(x, 100) makes 100 batches regardless of data
        # size (550 samples each on the 55000-sample training set), not batches
        # of BATCH_SIZE=100; split by sample count instead.
        num_batches = max(1, len(train_x) // BATCH_SIZE)
        x_batches = np.array_split(train_x, num_batches)
        y_batches = np.array_split(train_y, num_batches)

        for i, (batch_data, batch_labels) in enumerate(zip(x_batches, y_batches)):
            optimizer.zero_grad()  # clear gradients from the previous step
            out = model.forward(Tensor(batch_data))
            loss = criterion.forward(out, Tensor(batch_labels))
            loss.backward()
            optimizer.step()  # update weights with the new gradients

            # Validate every 100 batches (per the writeup pseudocode).
            # Bug fix: the original tested `BATCH_SIZE % 100 == 0`, which is
            # always true, so it validated after every single batch.
            if i % 100 == 0:
                accuracy = validate(model=model, val_x=val_x, val_y=val_y)
                val_accuracies.append(accuracy)
                model.train()  # validate() switches to eval mode; switch back

    return val_accuracies
3997

98+
4099
def validate(model, val_x, val_y):
    """Problem 3.3: Validation routine, tests on val data, scores accuracy
    Relevant Args:
        val_x (np.array): validation data (5000, 784)
        val_y (np.array): validation labels (5000,)
    Returns:
        float: Accuracy = correct / total
    """
    # Disable training-only behavior (e.g. dropout/batchnorm updates) for
    # evaluation; the caller is responsible for restoring train mode.
    model.eval()

    # Split into minibatches of BATCH_SIZE samples each.
    # Bug fix: np.array_split(x, 100) makes 100 batches regardless of data
    # size; split by sample count so batches are BATCH_SIZE wide as in training.
    num_batches = max(1, len(val_x) // BATCH_SIZE)
    x_batches = np.array_split(val_x, num_batches)
    y_batches = np.array_split(val_y, num_batches)

    num_correct = 0
    for batch_data, batch_labels in zip(x_batches, y_batches):
        out = model.forward(Tensor(batch_data))
        # Predicted class = index of the largest logit per sample.
        batch_preds = np.argmax(out.data, axis=1)
        # Bug fix: reduce each batch to a scalar count immediately; the
        # original accumulated whole boolean arrays into num_correct, which
        # raises a broadcasting error whenever batches have unequal lengths.
        num_correct += int((batch_preds == batch_labels).sum())

    return num_correct / len(val_x)

0 commit comments

Comments
 (0)