Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions computer_vision/mnist_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""
This program is a MNIST classifier using AlexNet.

For example, to train and test AlexNet with 1 and 2 MNIST samples with 4 training epochs.

Check failure on line 4 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E501)

computer_vision/mnist_classifier.py:4:89: E501 Line too long (89 > 88)

Check failure on line 4 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E501)

computer_vision/mnist_classifier.py:4:89: E501 Line too long (89 > 88)

Check failure on line 4 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E501)

computer_vision/mnist_classifier.py:4:89: E501 Line too long (89 > 88)

Check failure on line 4 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E501)

computer_vision/mnist_classifier.py:4:89: E501 Line too long (89 > 88)

Check failure on line 4 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E501)

computer_vision/mnist_classifier.py:4:89: E501 Line too long (89 > 88)
The command line input should be:
python program.py 1 2 4

"""

import sys
import torch
import torch.nn as n
import torchvision.datasets as dset
import torchvision.transforms

Check failure on line 14 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

computer_vision/mnist_classifier.py:14:8: F401 `torchvision.transforms` imported but unused

Check failure on line 14 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

computer_vision/mnist_classifier.py:14:8: F401 `torchvision.transforms` imported but unused

Check failure on line 14 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

computer_vision/mnist_classifier.py:14:8: F401 `torchvision.transforms` imported but unused

Check failure on line 14 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

computer_vision/mnist_classifier.py:14:8: F401 `torchvision.transforms` imported but unused

Check failure on line 14 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

computer_vision/mnist_classifier.py:14:8: F401 `torchvision.transforms` imported but unused
from torch.autograd import Variable
import torch.nn.functional as f
import torch.optim


class AlexNet(n.Module):

Check failure on line 20 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

computer_vision/mnist_classifier.py:10:1: I001 Import block is un-sorted or un-formatted

Check failure on line 20 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

computer_vision/mnist_classifier.py:10:1: I001 Import block is un-sorted or un-formatted

Check failure on line 20 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

computer_vision/mnist_classifier.py:10:1: I001 Import block is un-sorted or un-formatted

Check failure on line 20 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

computer_vision/mnist_classifier.py:10:1: I001 Import block is un-sorted or un-formatted

Check failure on line 20 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

computer_vision/mnist_classifier.py:10:1: I001 Import block is un-sorted or un-formatted
def __init__(self, num):

Check failure on line 21 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG002)

computer_vision/mnist_classifier.py:21:24: ARG002 Unused method argument: `num`

Check failure on line 21 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG002)

computer_vision/mnist_classifier.py:21:24: ARG002 Unused method argument: `num`

Check failure on line 21 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG002)

computer_vision/mnist_classifier.py:21:24: ARG002 Unused method argument: `num`

Check failure on line 21 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG002)

computer_vision/mnist_classifier.py:21:24: ARG002 Unused method argument: `num`

Check failure on line 21 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG002)

computer_vision/mnist_classifier.py:21:24: ARG002 Unused method argument: `num`

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide return type hint for the function: __init__. If the function does not return a value, please provide the type hint as: def function() -> None:

Please provide type hint for the parameter: num

super().__init__()
self.feature = n.Sequential(
# Define feature extractor here...
n.Conv2d(1, 32, kernel_size=5, stride=1, padding=1),
n.ReLU(inplace=True),
n.Conv2d(32, 64, kernel_size=3, padding=1),
n.ReLU(inplace=True),
n.MaxPool2d(kernel_size=2, stride=2),
n.Conv2d(64, 96, kernel_size=3, padding=1),
n.ReLU(inplace=True),
n.Conv2d(96, 64, kernel_size=3, padding=1),
n.ReLU(inplace=True),
n.Conv2d(64, 32, kernel_size=3, padding=1),
n.ReLU(inplace=True),
n.MaxPool2d(kernel_size=2, stride=1),
)

self.classifier = n.Sequential(
# Define classifier here...
n.Dropout(),
n.Linear(32 * 12 * 12, 2048),
n.ReLU(inplace=True),
n.Dropout(),
n.Linear(2048, 1024),
n.ReLU(inplace=True),
n.Linear(1024, 10),
)

def forward(self, x):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide return type hint for the function: forward. If the function does not return a value, please provide the type hint as: def function() -> None:

As there is no test file in this pull request nor any test function or class in the file computer_vision/mnist_classifier.py, please provide doctest for the function forward

Please provide type hint for the parameter: x

Please provide descriptive name for the parameter: x

# define forward network 'x' that combines feature extractor and classifier
x = self.feature(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x


def load_subset(full_train_set, full_test_set, label_one, label_two):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide return type hint for the function: load_subset. If the function does not return a value, please provide the type hint as: def function() -> None:

As there is no test file in this pull request nor any test function or class in the file computer_vision/mnist_classifier.py, please provide doctest for the function load_subset

Please provide type hint for the parameter: full_train_set

Please provide type hint for the parameter: full_test_set

Please provide type hint for the parameter: label_one

Please provide type hint for the parameter: label_two

# Sample the correct train labels
train_set = []
data_lim = 20000
for data in full_train_set:
if data_lim > 0:
data_lim -= 1
if data[1] == label_one or data[1] == label_two:
train_set.append(data)
else:
break

test_set = []
data_lim = 1000
for data in full_test_set:
if data_lim > 0:
data_lim -= 1
if data[1] == label_one or data[1] == label_two:
test_set.append(data)
else:
break

return train_set, test_set


def train(model, optimizer, train_loader, epoch):

Check failure on line 83 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG001)

computer_vision/mnist_classifier.py:83:43: ARG001 Unused function argument: `epoch`

Check failure on line 83 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG001)

computer_vision/mnist_classifier.py:83:43: ARG001 Unused function argument: `epoch`

Check failure on line 83 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG001)

computer_vision/mnist_classifier.py:83:43: ARG001 Unused function argument: `epoch`

Check failure on line 83 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG001)

computer_vision/mnist_classifier.py:83:43: ARG001 Unused function argument: `epoch`

Check failure on line 83 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG001)

computer_vision/mnist_classifier.py:83:43: ARG001 Unused function argument: `epoch`

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide return type hint for the function: train. If the function does not return a value, please provide the type hint as: def function() -> None:

As there is no test file in this pull request nor any test function or class in the file computer_vision/mnist_classifier.py, please provide doctest for the function train

Please provide type hint for the parameter: model

Please provide type hint for the parameter: optimizer

Please provide type hint for the parameter: train_loader

Please provide type hint for the parameter: epoch

model.train()
for data, target in enumerate(train_loader):
if torch.cuda.is_available():
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
optimizer.zero_grad()
output = model(data)
loss = f.cross_entropy(output, target)
loss.backward()
optimizer.step()


def test(model, test_loader):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide return type hint for the function: test. If the function does not return a value, please provide the type hint as: def function() -> None:

As there is no test file in this pull request nor any test function or class in the file computer_vision/mnist_classifier.py, please provide doctest for the function test

Please provide type hint for the parameter: model

Please provide type hint for the parameter: test_loader

model.eval()
test_loss = 0
correct = 0
for data, target in test_loader:
if torch.cuda.is_available():
data, target = data.cuda(), target.cuda()
with torch.no_grad():
data, target = Variable(data), Variable(target)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction="sum").item()

Check failure on line 106 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:106:22: F821 Undefined name `F`

Check failure on line 106 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:106:22: F821 Undefined name `F`

Check failure on line 106 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:106:22: F821 Undefined name `F`

Check failure on line 106 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:106:22: F821 Undefined name `F`

Check failure on line 106 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:106:22: F821 Undefined name `F`
# size_average=False
pred = output.data.max(1, keepdim=True)[1]
# get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()

test_loss /= len(test_loader.dataset)
acc = 100.0 * float(correct.to(torch.device("cpu")).numpy())
test_accuracy = acc / len(test_loader.dataset)
return test_accuracy


""" Start to call """

if __name__ == "__main__":
if len(sys.argv) == 3:
print("Usage: python assignment.py <number> <number>")
sys.exit(1)

input_data_one = sys.argv[1].strip()
input_data_two = sys.argv[2].strip()
epochs = sys.argv[3].strip()

""" Call to function that will perform the computation. """
if input_data_one.isdigit() and input_data_two.isdigit() and epochs.isdigit():
label_one = int(input_data_one)
label_two = int(input_data_two)
epochs = int(epochs)

if label_one != label_two and 0 <= label_one <= 9 and 0 <= label_two <= 9:
torch.manual_seed(42)
# Load MNIST dataset
trans = transforms.Compose(

Check failure on line 138 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:138:21: F821 Undefined name `transforms`

Check failure on line 138 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:138:21: F821 Undefined name `transforms`

Check failure on line 138 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:138:21: F821 Undefined name `transforms`

Check failure on line 138 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:138:21: F821 Undefined name `transforms`

Check failure on line 138 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:138:21: F821 Undefined name `transforms`
[transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]

Check failure on line 139 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:139:18: F821 Undefined name `transforms`

Check failure on line 139 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:139:41: F821 Undefined name `transforms`

Check failure on line 139 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:139:18: F821 Undefined name `transforms`

Check failure on line 139 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:139:41: F821 Undefined name `transforms`

Check failure on line 139 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:139:18: F821 Undefined name `transforms`

Check failure on line 139 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:139:41: F821 Undefined name `transforms`

Check failure on line 139 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:139:18: F821 Undefined name `transforms`

Check failure on line 139 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:139:41: F821 Undefined name `transforms`

Check failure on line 139 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:139:18: F821 Undefined name `transforms`

Check failure on line 139 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:139:41: F821 Undefined name `transforms`
)
full_train_set = dset.MNIST(
root="./data", train=True, transform=trans, download=True
)
full_test_set = dset.MNIST(root="./data", train=False, transform=trans)
batch_size = 16
# Get final train and test sets
train_set, test_set = load_subset(
full_train_set, full_test_set, label_one, label_two
)

train_loader = torch.utils.data.DataLoader(
dataset=train_set, batch_size=batch_size, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
dataset=test_set, batch_size=batch_size, shuffle=False
)

model = AlexNet()
if torch.cuda.is_available():
model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01)

Check failure on line 162 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:162:25: F821 Undefined name `optim`

Check failure on line 162 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:162:25: F821 Undefined name `optim`

Check failure on line 162 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:162:25: F821 Undefined name `optim`

Check failure on line 162 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:162:25: F821 Undefined name `optim`

Check failure on line 162 in computer_vision/mnist_classifier.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

computer_vision/mnist_classifier.py:162:25: F821 Undefined name `optim`

for epoch in range(1, epochs + 1):
train(model, optimizer, train_loader, epoch)
accuracy = test(model, test_loader)

print(round(accuracy, 2))

else:
print("Invalid input")
else:
print("Invalid input")


""" End to call """
Loading