Skip to content

Commit 1680eb6

Browse files
authored
Merge pull request #5 from mfl28/convolutional-neural-network
Add Convolutional Neural Network tutorial
2 parents 8cca2e4 + a60d12e commit 1680eb6

File tree

6 files changed

+211
-1
lines changed

6 files changed

+211
-1
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ add_subdirectory("tutorials/basics/feedforward_neural_network")
1818
add_subdirectory("tutorials/basics/linear_regression")
1919
add_subdirectory("tutorials/basics/logistic_regression")
2020
add_subdirectory("tutorials/basics/pytorch_basics")
21+
add_subdirectory("tutorials/intermediate/convolutional_neural_network")
2122

2223
# The following code block is suggested to be used on Windows.
2324
# According to https://github.com/pytorch/pytorch/issues/25457,

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ $ ./scripts.sh build
2929
* [Feedforward Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/basics/feedforward_neural_network/main.cpp)
3030

3131
#### 2. Intermediate
32-
* [Convolutional Neural Network]()
32+
* [Convolutional Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/convolutional_neural_network/src/main.cpp)
3333
* [Deep Residual Network]()
3434
* [Recurrent Neural Network]()
3535
* [Bidirectional Recurrent Neural Network]()
@@ -47,4 +47,5 @@ $ ./scripts.sh build
4747

4848
## Authors
4949
- Omkar Prabhu - [prabhuomkar](https://github.com/prabhuomkar)
50+
- Markus Fleischhacker - [mfl28](https://github.com/mfl28)
5051

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
2+
3+
project(convolutional-neural-network VERSION 1.0.0 LANGUAGES CXX)
4+
5+
# Files
6+
set(SOURCES src/main.cpp
7+
src/convnet.cpp
8+
)
9+
10+
set(HEADERS include/convnet.h
11+
)
12+
13+
set(EXECUTABLE_NAME convolutional-neural-network)
14+
15+
16+
add_executable(${EXECUTABLE_NAME} ${SOURCES} ${HEADERS})
17+
target_include_directories(${EXECUTABLE_NAME} PRIVATE include)
18+
19+
target_link_libraries(${EXECUTABLE_NAME} "${TORCH_LIBRARIES}")
20+
21+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES
22+
CXX_STANDARD 11
23+
CXX_STANDARD_REQUIRED YES
24+
)
25+
26+
# The following code block is suggested to be used on Windows.
27+
# According to https://github.com/pytorch/pytorch/issues/25457,
28+
# the DLLs need to be copied to avoid memory errors.
29+
# See https://pytorch.org/cppdocs/installing.html.
30+
if (MSVC)
31+
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
32+
add_custom_command(TARGET ${EXECUTABLE_NAME}
33+
POST_BUILD
34+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
35+
${TORCH_DLLS}
36+
$<TARGET_FILE_DIR:${EXECUTABLE_NAME}>)
37+
endif (MSVC)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright 2019 Markus Fleischhacker
2+
#pragma once
3+
4+
#include <torch/torch.h>
5+
6+
class ConvNetImpl : public torch::nn::Module {
7+
public:
8+
explicit ConvNetImpl(int64_t num_classes = 10);
9+
torch::Tensor forward(torch::Tensor x);
10+
11+
private:
12+
torch::nn::Sequential layer1{
13+
torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 16, 5).stride(1).padding(2)),
14+
torch::nn::BatchNorm(16),
15+
torch::nn::Functional(torch::relu),
16+
torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))
17+
};
18+
19+
torch::nn::Sequential layer2{
20+
torch::nn::Conv2d(torch::nn::Conv2dOptions(16, 32, 5).stride(1).padding(2)),
21+
torch::nn::BatchNorm(32),
22+
torch::nn::Functional(torch::relu),
23+
torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))
24+
};
25+
26+
torch::nn::Linear fc;
27+
};
28+
29+
TORCH_MODULE(ConvNet);
30+
31+
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright 2019 Markus Fleischhacker
2+
#include "convnet.h"
3+
#include <torch/torch.h>
4+
5+
ConvNetImpl::ConvNetImpl(int64_t num_classes)
6+
: fc(7 * 7 * 32, num_classes) {
7+
register_module("layer1", layer1);
8+
register_module("layer2", layer2);
9+
register_module("fc", fc);
10+
}
11+
12+
torch::Tensor ConvNetImpl::forward(torch::Tensor x) {
13+
x = layer1->forward(x);
14+
x = layer2->forward(x);
15+
x = x.view({-1, 7 * 7 * 32});
16+
x = fc->forward(x);
17+
return torch::log_softmax(x, 1);
18+
}
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Copyright 2019 Markus Fleischhacker
2+
#include <torch/torch.h>
3+
#include <iostream>
4+
#include <iomanip>
5+
#include "convnet.h"
6+
7+
int main() {
8+
std::cout << "Convolutional Neural Network\n\n";
9+
10+
// Device
11+
torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
12+
13+
// Hyper parameters
14+
const int64_t num_epochs = 5;
15+
const int64_t num_classes = 10;
16+
const int64_t batch_size = 100;
17+
const double learning_rate = 0.001;
18+
19+
const std::string MNIST_data_path = "../../../../tutorials/intermediate/convolutional_neural_network/data/";
20+
21+
// MNIST dataset
22+
auto train_dataset = torch::data::datasets::MNIST(MNIST_data_path)
23+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
24+
.map(torch::data::transforms::Stack<>());
25+
26+
// Number of samples in the training set
27+
auto num_train_samples = train_dataset.size().value();
28+
29+
auto test_dataset = torch::data::datasets::MNIST(MNIST_data_path, torch::data::datasets::MNIST::Mode::kTest)
30+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
31+
.map(torch::data::transforms::Stack<>());
32+
33+
// Number of samples in the testset
34+
auto num_test_samples = test_dataset.size().value();
35+
36+
// Data loader
37+
auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
38+
std::move(train_dataset), batch_size);
39+
auto test_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
40+
std::move(test_dataset), batch_size);
41+
42+
// Model
43+
ConvNet model(num_classes);
44+
model->to(device);
45+
46+
// Optimizer
47+
auto optimizer = torch::optim::Adam(model->parameters(), torch::optim::AdamOptions(learning_rate));
48+
49+
// Set floating point output precision
50+
std::cout << std::fixed << std::setprecision(4);
51+
52+
std::cout << "Training...\n";
53+
54+
// Train the model
55+
for (size_t epoch = 0; epoch != num_epochs; ++epoch) {
56+
// Initialize running metrics
57+
float running_loss = 0.0;
58+
size_t num_correct = 0;
59+
60+
for (auto& batch : *train_loader) {
61+
// Transfer images and target labels to device
62+
auto data = batch.data.to(device);
63+
auto target = batch.target.to(device);
64+
65+
// Forward pass
66+
auto output = model->forward(data);
67+
68+
// Calculate loss
69+
auto loss = torch::nll_loss(output, target);
70+
71+
// Update running loss
72+
running_loss += loss.item().toFloat() * data.size(0);
73+
74+
// Calculate prediction
75+
auto prediction = output.argmax(1);
76+
77+
// Update number of correctly classified samples
78+
num_correct += prediction.eq(target).sum().item().toLong();
79+
80+
// Backward pass and optimize
81+
optimizer.zero_grad();
82+
loss.backward();
83+
optimizer.step();
84+
}
85+
86+
auto sample_mean_loss = running_loss / num_train_samples;
87+
auto accuracy = static_cast<float>(num_correct) / num_train_samples;
88+
89+
std::cout << "Epoch [" << (epoch + 1) << "/" << num_epochs << "], Trainset - Loss: "
90+
<< sample_mean_loss << ", Accuracy: " << accuracy << '\n';
91+
}
92+
93+
std::cout << "Training finished!\n\n";
94+
std::cout << "Testing...\n";
95+
96+
// Test the model
97+
model->eval();
98+
torch::NoGradGuard no_grad;
99+
100+
float running_loss = 0.0;
101+
size_t num_correct = 0;
102+
103+
for (const auto& batch : *test_loader) {
104+
auto data = batch.data.to(device);
105+
auto target = batch.target.to(device);
106+
107+
auto output = model->forward(data);
108+
109+
auto loss = torch::nll_loss(output, target);
110+
running_loss += loss.item().toFloat() * data.size(0);
111+
112+
auto prediction = output.argmax(1);
113+
num_correct += prediction.eq(target).sum().item().toLong();
114+
}
115+
116+
std::cout << "Testing finished!\n";
117+
118+
auto test_accuracy = static_cast<float>(num_correct) / num_test_samples;
119+
auto test_sample_mean_loss = running_loss / num_test_samples;
120+
121+
std::cout << "Testset - Loss: " << test_sample_mean_loss << ", Accuracy: " << test_accuracy << '\n';
122+
}

0 commit comments

Comments
 (0)