Skip to content

Commit 9b4fd09

Browse files
authored
Merge pull request #9 from mfl28/recurrent-neural-network
Add Recurrent Neural Network tutorial
2 parents 1c798db + 903efc6 commit 9b4fd09

File tree

6 files changed

+199
-1
lines changed

6 files changed

+199
-1
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_subdirectory("tutorials/basics/logistic_regression")
2020
add_subdirectory("tutorials/basics/pytorch_basics")
2121
add_subdirectory("tutorials/intermediate/convolutional_neural_network")
2222
add_subdirectory("tutorials/intermediate/deep_residual_network")
23+
add_subdirectory("tutorials/intermediate/recurrent_neural_network")
2324

2425
# The following code block is suggested to be used on Windows.
2526
# According to https://github.com/pytorch/pytorch/issues/25457,

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ $ ./scripts.sh build
3131
#### 2. Intermediate
3232
* [Convolutional Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/convolutional_neural_network/src/main.cpp)
3333
* [Deep Residual Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/deep_residual_network/src/main.cpp)
34-
* [Recurrent Neural Network]()
34+
* [Recurrent Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/recurrent_neural_network/src/main.cpp)
3535
* [Bidirectional Recurrent Neural Network]()
3636
* [Language Model (RNN-LM)]()
3737

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
2+
3+
project(recurrent-neural-network VERSION 1.0.0 LANGUAGES CXX)
4+
5+
# Files
6+
set(SOURCES src/main.cpp
7+
src/rnn.cpp
8+
)
9+
10+
set(HEADERS include/rnn.h
11+
)
12+
13+
set(EXECUTABLE_NAME recurrent-neural-network)
14+
15+
16+
add_executable(${EXECUTABLE_NAME} ${SOURCES} ${HEADERS})
17+
target_include_directories(${EXECUTABLE_NAME} PRIVATE include)
18+
19+
target_link_libraries(${EXECUTABLE_NAME} "${TORCH_LIBRARIES}")
20+
21+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES
22+
CXX_STANDARD 11
23+
CXX_STANDARD_REQUIRED YES
24+
)
25+
26+
# The following code block is suggested to be used on Windows.
27+
# According to https://github.com/pytorch/pytorch/issues/25457,
28+
# the DLLs need to be copied to avoid memory errors.
29+
# See https://pytorch.org/cppdocs/installing.html.
30+
if (MSVC)
31+
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
32+
add_custom_command(TARGET ${EXECUTABLE_NAME}
33+
POST_BUILD
34+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
35+
${TORCH_DLLS}
36+
$<TARGET_FILE_DIR:${EXECUTABLE_NAME}>)
37+
endif (MSVC)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Copyright 2019 Markus Fleischhacker
2+
#pragma once
3+
4+
#include <torch/torch.h>
5+
6+
class RNNImpl : public torch::nn::Module {
7+
public:
8+
RNNImpl(int64_t input_size, int64_t hidden_size, int64_t num_layers, int64_t num_classes);
9+
torch::Tensor forward(torch::Tensor x);
10+
11+
private:
12+
torch::nn::LSTM lstm;
13+
torch::nn::Linear fc;
14+
};
15+
16+
TORCH_MODULE(RNN);
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// Copyright 2019 Markus Fleischhacker
2+
#include <torch/torch.h>
3+
#include <iostream>
4+
#include <iomanip>
5+
#include "rnn.h"
6+
7+
int main() {
8+
std::cout << "Recurrent Neural Network\n\n";
9+
10+
// Device
11+
torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
12+
13+
// Hyper parameters
14+
const int64_t sequence_length = 28;
15+
const int64_t input_size = 28;
16+
const int64_t hidden_size = 128;
17+
const int64_t num_layers = 2;
18+
const int64_t num_classes = 10;
19+
const int64_t batch_size = 100;
20+
const int64_t num_epochs = 2;
21+
const double learning_rate = 0.01;
22+
23+
const std::string MNIST_data_path = "../../../../tutorials/intermediate/recurrent_neural_network/data/";
24+
25+
// MNIST dataset
26+
auto train_dataset = torch::data::datasets::MNIST(MNIST_data_path)
27+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
28+
.map(torch::data::transforms::Stack<>());
29+
30+
// Number of samples in the training set
31+
auto num_train_samples = train_dataset.size().value();
32+
33+
auto test_dataset = torch::data::datasets::MNIST(MNIST_data_path, torch::data::datasets::MNIST::Mode::kTest)
34+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
35+
.map(torch::data::transforms::Stack<>());
36+
37+
// Number of samples in the testset
38+
auto num_test_samples = test_dataset.size().value();
39+
40+
// Data loader
41+
auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
42+
std::move(train_dataset), batch_size);
43+
auto test_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
44+
std::move(test_dataset), batch_size);
45+
46+
// Model
47+
RNN model(input_size, hidden_size, num_layers, num_classes);
48+
model->to(device);
49+
50+
// Optimizer
51+
auto optimizer = torch::optim::Adam(model->parameters(), torch::optim::AdamOptions(learning_rate));
52+
53+
// Set floating point output precision
54+
std::cout << std::fixed << std::setprecision(4);
55+
56+
std::cout << "Training...\n";
57+
58+
// Train the model
59+
for (size_t epoch = 0; epoch != num_epochs; ++epoch) {
60+
// Initialize running metrics
61+
float running_loss = 0.0;
62+
size_t num_correct = 0;
63+
64+
for (auto& batch : *train_loader) {
65+
// Transfer images and target labels to device
66+
auto data = batch.data.view({-1, sequence_length, input_size}).to(device);
67+
auto target = batch.target.to(device);
68+
69+
// Forward pass
70+
auto output = model->forward(data);
71+
72+
// Calculate loss
73+
auto loss = torch::nll_loss(output, target);
74+
// Update running loss
75+
running_loss += loss.item().toFloat() * data.size(0);
76+
77+
// Calculate prediction
78+
auto prediction = output.argmax(1);
79+
80+
// Update number of correctly classified samples
81+
num_correct += prediction.eq(target).sum().item().toLong();
82+
83+
// Backward pass and optimize
84+
optimizer.zero_grad();
85+
loss.backward();
86+
optimizer.step();
87+
}
88+
89+
auto sample_mean_loss = running_loss / num_train_samples;
90+
auto accuracy = static_cast<float>(num_correct) / num_train_samples;
91+
92+
std::cout << "Epoch [" << (epoch + 1) << "/" << num_epochs << "], Trainset - Loss: "
93+
<< sample_mean_loss << ", Accuracy: " << accuracy << '\n';
94+
}
95+
96+
std::cout << "Training finished!\n\n";
97+
std::cout << "Testing...\n";
98+
99+
// Test the model
100+
model->eval();
101+
torch::NoGradGuard no_grad;
102+
103+
float running_loss = 0.0;
104+
size_t num_correct = 0;
105+
106+
for (const auto& batch : *test_loader) {
107+
auto data = batch.data.view({-1, sequence_length, input_size}).to(device);
108+
auto target = batch.target.to(device);
109+
110+
auto output = model->forward(data);
111+
112+
auto loss = torch::nll_loss(output, target);
113+
running_loss += loss.item().toFloat() * data.size(0);
114+
115+
auto prediction = output.argmax(1);
116+
num_correct += prediction.eq(target).sum().item().toLong();
117+
}
118+
119+
std::cout << "Testing finished!\n";
120+
121+
auto test_accuracy = static_cast<float>(num_correct) / num_test_samples;
122+
auto test_sample_mean_loss = running_loss / num_test_samples;
123+
124+
std::cout << "Testset - Loss: " << test_sample_mean_loss << ", Accuracy: " << test_accuracy << '\n';
125+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright 2019 Markus Fleischhacker
2+
#include "rnn.h"
3+
#include <torch/torch.h>
4+
5+
RNNImpl::RNNImpl(int64_t input_size, int64_t hidden_size, int64_t num_layers, int64_t num_classes)
6+
: lstm(torch::nn::LSTMOptions(input_size, hidden_size).layers(num_layers).batch_first(true)),
7+
fc(hidden_size, num_classes) {
8+
register_module("lstm", lstm);
9+
register_module("fc", fc);
10+
}
11+
12+
torch::Tensor RNNImpl::forward(torch::Tensor x) {
13+
auto out = lstm->forward(x)
14+
.output
15+
.slice(1, -1)
16+
.squeeze(1);
17+
out = fc->forward(out);
18+
return torch::log_softmax(out, 1);
19+
}

0 commit comments

Comments
 (0)