Skip to content

Commit 3eda592

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d45f819 commit 3eda592

File tree

1 file changed

+58
-46
lines changed

1 file changed

+58
-46
lines changed
Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
This program is a MNIST classifier using AlexNet. It accepts three parameters provided as a command line input.
3-
The first two inputs are two digits between 0-9 which are used to train and test the classifier and the third
2+
This program is a MNIST classifier using AlexNet. It accepts three parameters provided as a command line input.
3+
The first two inputs are two digits between 0-9 which are used to train and test the classifier and the third
44
parameter controls the number of training epochs.
55
Syntax: python program.py <number> <number> <number>
66
@@ -9,7 +9,6 @@
99
1010
"""
1111

12-
1312
import sys
1413
import torch
1514
import torch.nn as nn
@@ -20,7 +19,7 @@
2019
import torch.optim as optim
2120

2221

23-
class AlexNet(nn.Module):
22+
class AlexNet(nn.Module):
2423
def __init__(self, num=10):
2524
super(AlexNet, self).__init__()
2625
self.feature = nn.Sequential(
@@ -36,20 +35,20 @@ def __init__(self, num=10):
3635
nn.ReLU(inplace=True),
3736
nn.Conv2d(64, 32, kernel_size=3, padding=1),
3837
nn.ReLU(inplace=True),
39-
nn.MaxPool2d(kernel_size=2, stride=1)
38+
nn.MaxPool2d(kernel_size=2, stride=1),
4039
)
4140

4241
self.classifier = nn.Sequential(
4342
# Define classifier here...
4443
nn.Dropout(),
45-
nn.Linear(32*12*12, 2048),
44+
nn.Linear(32 * 12 * 12, 2048),
4645
nn.ReLU(inplace=True),
4746
nn.Dropout(),
4847
nn.Linear(2048, 1024),
4948
nn.ReLU(inplace=True),
50-
nn.Linear(1024, 10)
49+
nn.Linear(1024, 10),
5150
)
52-
51+
5352
def forward(self, x):
5453
# define forward network 'x' that combines feature extractor and classifier
5554
x = self.feature(x)
@@ -63,26 +62,27 @@ def load_subset(full_train_set, full_test_set, label_one, label_two):
6362
train_set = []
6463
data_lim = 20000
6564
for data in full_train_set:
66-
if data_lim>0:
67-
data_lim-=1
68-
if data[1]==label_one or data[1]==label_two:
65+
if data_lim > 0:
66+
data_lim -= 1
67+
if data[1] == label_one or data[1] == label_two:
6968
train_set.append(data)
7069
else:
7170
break
7271

7372
test_set = []
7473
data_lim = 1000
7574
for data in full_test_set:
76-
if data_lim>0:
77-
data_lim-=1
78-
if data[1]==label_one or data[1]==label_two:
75+
if data_lim > 0:
76+
data_lim -= 1
77+
if data[1] == label_one or data[1] == label_two:
7978
test_set.append(data)
8079
else:
8180
break
8281

8382
return train_set, test_set
8483

85-
def train(model,optimizer,train_loader,epoch):
84+
85+
def train(model, optimizer, train_loader, epoch):
8686
model.train()
8787
for batch_idx, (data, target) in enumerate(train_loader):
8888
if torch.cuda.is_available():
@@ -94,7 +94,8 @@ def train(model,optimizer,train_loader,epoch):
9494
loss.backward()
9595
optimizer.step()
9696

97-
def test(model,test_loader):
97+
98+
def test(model, test_loader):
9899
model.eval()
99100
test_loss = 0
100101
correct = 0
@@ -104,65 +105,76 @@ def test(model,test_loader):
104105
with torch.no_grad():
105106
data, target = Variable(data), Variable(target)
106107
output = model(data)
107-
test_loss += F.cross_entropy(output, target, reduction='sum').item()#size_average=False
108-
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
108+
test_loss += F.cross_entropy(
109+
output, target, reduction="sum"
110+
).item() # size_average=False
111+
pred = output.data.max(1, keepdim=True)[
112+
1
113+
] # get the index of the max log-probability
109114
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
110-
115+
111116
test_loss /= len(test_loader.dataset)
112-
acc=100. * float(correct.to(torch.device('cpu')).numpy())
113-
test_accuracy = (acc / len(test_loader.dataset))
117+
acc = 100.0 * float(correct.to(torch.device("cpu")).numpy())
118+
test_accuracy = acc / len(test_loader.dataset)
114119
return test_accuracy
115120

116121

117122
""" Start to call """
118123

119-
if __name__ == '__main__':
120-
124+
if __name__ == "__main__":
121125
if len(sys.argv) == 3:
122126
print("Usage: python assignment.py <number> <number>")
123127
sys.exit(1)
124128

125129
input_data_one = sys.argv[1].strip()
126130
input_data_two = sys.argv[2].strip()
127131
epochs = sys.argv[3].strip()
128-
132+
129133
""" Call to function that will perform the computation. """
130134
if input_data_one.isdigit() and input_data_two.isdigit() and epochs.isdigit():
131-
132135
label_one = int(input_data_one)
133136
label_two = int(input_data_two)
134137
epochs = int(epochs)
135-
136-
if label_one!=label_two and 0<=label_one<=9 and 0<=label_two<=9:
138+
139+
if label_one != label_two and 0 <= label_one <= 9 and 0 <= label_two <= 9:
137140
torch.manual_seed(42)
138141
# Load MNIST dataset
139-
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
140-
full_train_set = dset.MNIST(root='./data', train=True, transform=trans, download=True)
141-
full_test_set = dset.MNIST(root='./data', train=False, transform=trans)
142+
trans = transforms.Compose(
143+
[transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]
144+
)
145+
full_train_set = dset.MNIST(
146+
root="./data", train=True, transform=trans, download=True
147+
)
148+
full_test_set = dset.MNIST(root="./data", train=False, transform=trans)
142149
batch_size = 16
143150
# Get final train and test sets
144-
train_set, test_set = load_subset(full_train_set,full_test_set,label_one,label_two)
145-
146-
train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=batch_size,shuffle=False)
147-
test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=batch_size,shuffle=False)
151+
train_set, test_set = load_subset(
152+
full_train_set, full_test_set, label_one, label_two
153+
)
154+
155+
train_loader = torch.utils.data.DataLoader(
156+
dataset=train_set, batch_size=batch_size, shuffle=False
157+
)
158+
test_loader = torch.utils.data.DataLoader(
159+
dataset=test_set, batch_size=batch_size, shuffle=False
160+
)
148161

149162
model = AlexNet()
150163
if torch.cuda.is_available():
151164
model.cuda()
152-
165+
153166
optimizer = optim.SGD(model.parameters(), lr=0.01)
154-
155-
for epoch in range(1, epochs+1):
156-
train(model,optimizer,train_loader,epoch)
157-
accuracy = test(model,test_loader)
158-
159-
print(round(accuracy,2))
160-
161-
167+
168+
for epoch in range(1, epochs + 1):
169+
train(model, optimizer, train_loader, epoch)
170+
accuracy = test(model, test_loader)
171+
172+
print(round(accuracy, 2))
173+
162174
else:
163-
print("Invalid input")
175+
print("Invalid input")
164176
else:
165177
print("Invalid input")
166-
167-
178+
179+
168180
""" End to call """

0 commit comments

Comments
 (0)