
Commit 1c798db

Merge pull request #7 from mfl28/deep-residual-network
Add Deep Residual Network tutorial
2 parents 4d369b8 + 9180bbd commit 1c798db

File tree

12 files changed (+596 lines, -4 lines)

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ add_subdirectory("tutorials/basics/linear_regression")
 add_subdirectory("tutorials/basics/logistic_regression")
 add_subdirectory("tutorials/basics/pytorch_basics")
 add_subdirectory("tutorials/intermediate/convolutional_neural_network")
+add_subdirectory("tutorials/intermediate/deep_residual_network")

 # The following code block is suggested to be used on Windows.
 # According to https://github.com/pytorch/pytorch/issues/25457,

README.md

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ $ ./scripts.sh build

 #### 2. Intermediate
 * [Convolutional Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/convolutional_neural_network/src/main.cpp)
-* [Deep Residual Network]()
+* [Deep Residual Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/deep_residual_network/src/main.cpp)
 * [Recurrent Neural Network]()
 * [Bidirectional Recurrent Neural Network]()
 * [Language Model (RNN-LM)]()

scripts.sh

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
 #!/bin/bash

 function install() {
-    wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
-    unzip libtorch-shared-with-deps-latest.zip
-    rm -rf libtorch-shared-with-deps-latest.zip
+    wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.3.1%2Bcpu.zip
+    unzip libtorch-shared-with-deps-1.3.1+cpu.zip
+    rm -rf libtorch-shared-with-deps-1.3.1+cpu.zip
 }

 function build() {
tutorials/intermediate/deep_residual_network/CMakeLists.txt

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
+
+project(deep-residual-network VERSION 1.0.0 LANGUAGES CXX)
+
+# Files
+set(SOURCES src/main.cpp
+            src/cifar10.cpp
+            src/residual_block.cpp
+            src/transform.cpp
+)
+
+set(HEADERS include/residual_block.h
+            include/resnet.h
+            include/cifar10.h
+            include/transform.h
+)
+
+set(EXECUTABLE_NAME deep-residual-network)
+
+
+add_executable(${EXECUTABLE_NAME} ${SOURCES} ${HEADERS})
+target_include_directories(${EXECUTABLE_NAME} PRIVATE include)
+
+target_link_libraries(${EXECUTABLE_NAME} "${TORCH_LIBRARIES}")
+
+set_target_properties(${EXECUTABLE_NAME} PROPERTIES
+    CXX_STANDARD 11
+    CXX_STANDARD_REQUIRED YES
+)
+
+# The following code block is suggested to be used on Windows.
+# According to https://github.com/pytorch/pytorch/issues/25457,
+# the DLLs need to be copied to avoid memory errors.
+# See https://pytorch.org/cppdocs/installing.html.
+if (MSVC)
+    file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
+    add_custom_command(TARGET ${EXECUTABLE_NAME}
+                       POST_BUILD
+                       COMMAND ${CMAKE_COMMAND} -E copy_if_different
+                       ${TORCH_DLLS}
+                       $<TARGET_FILE_DIR:${EXECUTABLE_NAME}>)
+endif (MSVC)
tutorials/intermediate/deep_residual_network/include/cifar10.h

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+// Copyright 2019 Markus Fleischhacker
+#pragma once
+
+#include <torch/data/datasets/base.h>
+#include <torch/data/example.h>
+#include <torch/types.h>
+#include <cstddef>
+#include <string>
+
+// CIFAR10 dataset
+// based on: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/data/datasets/mnist.h.
+class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
+ public:
+    // The mode in which the dataset is loaded
+    enum Mode { kTrain, kTest };
+
+    // Loads the CIFAR10 dataset from the `root` path.
+    //
+    // The supplied `root` path should contain the *content* of the unzipped
+    // CIFAR10 dataset (binary version), available from http://www.cs.toronto.edu/~kriz/cifar.html.
+    explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain);
+
+    // Returns the `Example` at the given `index`.
+    torch::data::Example<> get(size_t index) override;
+
+    // Returns the size of the dataset.
+    torch::optional<size_t> size() const override;
+
+    // Returns true if this is the training subset of CIFAR10.
+    bool is_train() const noexcept;
+
+    // Returns all images stacked into a single tensor.
+    const torch::Tensor& images() const;
+
+    // Returns all targets stacked into a single tensor.
+    const torch::Tensor& targets() const;
+
+ private:
+    torch::Tensor images_;
+    torch::Tensor targets_;
+    Mode mode_;
+};
+
tutorials/intermediate/deep_residual_network/include/residual_block.h

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+// Copyright 2019 Markus Fleischhacker
+#pragma once
+
+#include <torch/torch.h>
+
+namespace resnet {
+class ResidualBlockImpl : public torch::nn::Module {
+ public:
+    ResidualBlockImpl(int64_t in_channels, int64_t out_channels, int64_t stride = 1,
+        torch::nn::Sequential downsample = nullptr);
+    torch::Tensor forward(torch::Tensor x);
+
+ private:
+    torch::nn::Conv2d conv1;
+    torch::nn::BatchNorm bn1;
+    torch::nn::Functional relu;
+    torch::nn::Conv2d conv2;
+    torch::nn::BatchNorm bn2;
+    torch::nn::Sequential downsampler;
+};
+
+torch::nn::Conv2d conv3x3(int64_t in_channels, int64_t out_channels, int64_t stride = 1);
+
+TORCH_MODULE(ResidualBlock);
+} // namespace resnet
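
The matching implementation, src/residual_block.cpp, is part of this commit but not shown in this excerpt. For orientation, here is a hypothetical sketch of how a basic post-activation residual block is commonly implemented against the header above; the constructor wiring and the Conv2dOptions spelling (with_bias in libtorch 1.3.x, bias in later releases) are assumptions and may not match the committed file.

// Hypothetical sketch; not the committed src/residual_block.cpp.
#include "residual_block.h"

namespace resnet {
torch::nn::Conv2d conv3x3(int64_t in_channels, int64_t out_channels, int64_t stride) {
    // 3x3 convolution with padding 1 keeps the spatial size (up to the stride).
    return torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 3)
        .stride(stride)
        .padding(1)
        .with_bias(false));  // spelled .bias(false) in newer libtorch versions
}

ResidualBlockImpl::ResidualBlockImpl(int64_t in_channels, int64_t out_channels,
    int64_t stride, torch::nn::Sequential downsample)
    : conv1(conv3x3(in_channels, out_channels, stride)),
      bn1(out_channels),
      relu(torch::relu),
      conv2(conv3x3(out_channels, out_channels)),
      bn2(out_channels),
      downsampler(downsample) {
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("relu", relu);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    if (downsampler) {
        register_module("downsampler", downsampler);
    }
}

torch::Tensor ResidualBlockImpl::forward(torch::Tensor x) {
    auto residual = x;
    auto out = conv1->forward(x);
    out = bn1->forward(out);
    out = relu->forward(out);
    out = conv2->forward(out);
    out = bn2->forward(out);
    // When the block changes the spatial size or channel count, the shortcut
    // needs the same projection before the two paths can be added.
    if (downsampler) {
        residual = downsampler->forward(x);
    }
    out += residual;
    return relu->forward(out);
}
} // namespace resnet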
tutorials/intermediate/deep_residual_network/include/resnet.h

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+// Copyright 2019 Markus Fleischhacker
+#pragma once
+
+#include <torch/torch.h>
+#include <vector>
+#include "residual_block.h"
+
+namespace resnet {
+template<typename Block>
+class ResNetImpl : public torch::nn::Module {
+ public:
+    explicit ResNetImpl(const std::array<int64_t, 3>& layers, int64_t num_classes = 10);
+    torch::Tensor forward(torch::Tensor x);
+
+ private:
+    int64_t in_channels = 16;
+    torch::nn::Conv2d conv{conv3x3(3, 16)};
+    torch::nn::BatchNorm bn{16};
+    torch::nn::Functional relu{torch::relu};
+    torch::nn::Sequential layer1;
+    torch::nn::Sequential layer2;
+    torch::nn::Sequential layer3;
+    torch::nn::AvgPool2d avg_pool{8};
+    torch::nn::Linear fc;
+
+    torch::nn::Sequential make_layer(int64_t out_channels, int64_t blocks, int64_t stride = 1);
+};
+
+template<typename Block>
+ResNetImpl<Block>::ResNetImpl(const std::array<int64_t, 3>& layers, int64_t num_classes) :
+    layer1(make_layer(16, layers[0])),
+    layer2(make_layer(32, layers[1], 2)),
+    layer3(make_layer(64, layers[2], 2)),
+    fc(64, num_classes) {
+    register_module("conv", conv);
+    register_module("bn", bn);
+    register_module("relu", relu);
+    register_module("layer1", layer1);
+    register_module("layer2", layer2);
+    register_module("layer3", layer3);
+    register_module("avg_pool", avg_pool);
+    register_module("fc", fc);
+}
+
+template<typename Block>
+torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
+    auto out = conv->forward(x);
+    out = bn->forward(out);
+    out = relu->forward(out);
+    out = layer1->forward(out);
+    out = layer2->forward(out);
+    out = layer3->forward(out);
+    out = avg_pool->forward(out);
+    out = out.view({out.size(0), -1});
+    out = fc->forward(out);
+
+    return torch::log_softmax(out, 1);
+}
+
+template<typename Block>
+torch::nn::Sequential ResNetImpl<Block>::make_layer(int64_t out_channels, int64_t blocks, int64_t stride) {
+    torch::nn::Sequential layers;
+    torch::nn::Sequential downsample{nullptr};
+
+    if (stride != 1 || in_channels != out_channels) {
+        downsample = torch::nn::Sequential{
+            conv3x3(in_channels, out_channels, stride),
+            torch::nn::BatchNorm(out_channels)
+        };
+    }
+
+    layers->push_back(Block(in_channels, out_channels, stride, downsample));
+
+    in_channels = out_channels;
+
+    for (int64_t i = 1; i != blocks; ++i) {
+        layers->push_back(Block(out_channels, out_channels));
+    }
+
+    return layers;
+}
+
+// Wrap class into ModuleHolder (a shared_ptr wrapper),
+// see https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/pimpl.h
+template<typename Block = ResidualBlock>
+class ResNet : public torch::nn::ModuleHolder<ResNetImpl<Block>> {
+ public:
+    using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
+};
+} // namespace resnet
+
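
Since the tutorial's src/main.cpp is not shown in this excerpt, the following is a hypothetical usage sketch of the ResNet class template above. The {2, 2, 2} layer configuration, the batch size, and the nll_loss pairing are illustrative assumptions rather than the commit's actual training loop.

// Hypothetical usage sketch; not the committed src/main.cpp.
#include <torch/torch.h>
#include <array>
#include <iostream>
#include "resnet.h"

int main() {
    // Number of residual blocks per stage (assumed here; the real tutorial may differ).
    std::array<int64_t, 3> layers{2, 2, 2};
    resnet::ResNet<resnet::ResidualBlock> model(layers, /*num_classes=*/10);

    // A CIFAR10-sized batch: 4 RGB images of 32x32 pixels.
    auto input = torch::randn({4, 3, 32, 32});
    auto output = model->forward(input);  // log-probabilities of shape [4, 10]
    std::cout << output.sizes() << '\n';

    // forward() already applies log_softmax, so the matching loss is nll_loss.
    auto targets = torch::randint(0, 10, {4}, torch::kInt64);
    auto loss = torch::nll_loss(output, targets);
    std::cout << "loss: " << loss.item<float>() << '\n';
}

The avg_pool{8} member matches a 32x32 input: the two stride-2 stages reduce 32x32 to 8x8 with 64 channels, so the pooled features flatten to exactly the 64 inputs expected by the final linear layer.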
tutorials/intermediate/deep_residual_network/include/transform.h

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+// Copyright 2019 Markus Fleischhacker
+#pragma once
+
+#include <torch/torch.h>
+#include <random>
+
+namespace transform {
+class RandomHorizontalFlip : public torch::data::transforms::TensorTransform<torch::Tensor> {
+ public:
+    // Creates a transformation that randomly horizontally flips a tensor.
+    //
+    // The parameter `p` determines the probability that a tensor is flipped (default = 0.5).
+    explicit RandomHorizontalFlip(double p = 0.5);
+
+    torch::Tensor operator()(torch::Tensor input) override;
+
+ private:
+    double p_;
+};
+
+class ConstantPad : public torch::data::transforms::TensorTransform<torch::Tensor> {
+ public:
+    // Creates a transformation that pads a tensor.
+    //
+    // `padding` is expected to be a vector of size 4 whose entries correspond to the
+    // padding of the sides, i.e {left, right, top, bottom}. `value` determines the value
+    // for the padded pixels.
+    explicit ConstantPad(const std::vector<int64_t>& padding, torch::Scalar value = 0);
+
+    // Creates a transformation that pads a tensor.
+    //
+    // The padding will be performed using the size `padding` for all 4 sides.
+    // `value` determines the value for the padded pixels.
+    explicit ConstantPad(int64_t padding, torch::Scalar value = 0);
+
+    torch::Tensor operator()(torch::Tensor input) override;
+
+ private:
+    std::vector<int64_t> padding_;
+    torch::Scalar value_;
+};
+
+class RandomCrop : public torch::data::transforms::TensorTransform<torch::Tensor> {
+ public:
+    // Creates a transformation that randomly crops a tensor.
+    //
+    // The parameter `size` is expected to be a vector of size 2
+    // and determines the output size {height, width}.
+    explicit RandomCrop(const std::vector<int64_t>& size);
+    torch::Tensor operator()(torch::Tensor input) override;
+
+ private:
+    std::vector<int64_t> size_;
+};
+} // namespace transform
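
The implementations of these transforms live in src/transform.cpp, which this excerpt does not include. As an illustration, a minimal sketch of RandomHorizontalFlip, assuming a Bernoulli draw per example and a flip along the width dimension, could look like this; the committed implementation may differ in how it draws its random numbers.

// Hypothetical sketch; not the committed src/transform.cpp.
#include "transform.h"
#include <random>

namespace transform {
RandomHorizontalFlip::RandomHorizontalFlip(double p) : p_(p) {}

torch::Tensor RandomHorizontalFlip::operator()(torch::Tensor input) {
    thread_local std::mt19937 engine{std::random_device{}()};
    std::bernoulli_distribution flip_coin(p_);

    if (flip_coin(engine)) {
        // Unbatched CIFAR10 tensors are laid out as {channels, height, width},
        // so flipping dimension 2 mirrors the image horizontally.
        return input.flip({2});
    }
    return input;
}
} // namespace transform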
tutorials/intermediate/deep_residual_network/src/cifar10.cpp

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+// Copyright 2019 Markus Fleischhacker
+#include "cifar10.h"
+
+namespace {
+// CIFAR10 dataset description can be found at https://www.cs.toronto.edu/~kriz/cifar.html.
+constexpr uint32_t kTrainSize = 50000;
+constexpr uint32_t kTestSize = 10000;
+constexpr uint32_t kSizePerBatch = 10000;
+constexpr uint32_t kImageRows = 32;
+constexpr uint32_t kImageColumns = 32;
+constexpr uint32_t kBytesPerRow = 3073;
+constexpr uint32_t kBytesPerChannelPerRow = (kBytesPerRow - 1) / 3;
+constexpr uint32_t kBytesPerBatchFile = kBytesPerRow * kSizePerBatch;
+
+const std::vector<std::string> kTrainDataBatchFiles = {
+    "data_batch_1.bin",
+    "data_batch_2.bin",
+    "data_batch_3.bin",
+    "data_batch_4.bin",
+    "data_batch_5.bin",
+};
+
+const std::vector<std::string> kTestDataBatchFiles = {
+    "test_batch.bin"
+};
+
+// Source: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp.
+std::string join_paths(std::string head, const std::string& tail) {
+    if (head.back() != '/') {
+        head.push_back('/');
+    }
+    head += tail;
+    return head;
+}
+// Partially based on https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp.
+std::pair<torch::Tensor, torch::Tensor> read_data(const std::string& root, bool train) {
+    const auto& files = train ? kTrainDataBatchFiles : kTestDataBatchFiles;
+    const auto num_samples = train ? kTrainSize : kTestSize;
+
+    std::vector<char> data_buffer;
+    data_buffer.reserve(files.size() * kBytesPerBatchFile);
+
+    for (const auto& file : files) {
+        const auto path = join_paths(root, file);
+        std::ifstream data(path, std::ios::binary);
+        TORCH_CHECK(data, "Error opening data file at", path);
+
+        data_buffer.insert(data_buffer.end(), std::istreambuf_iterator<char>(data), {});
+    }
+
+    TORCH_CHECK(data_buffer.size() == files.size() * kBytesPerBatchFile, "Unexpected file sizes");
+
+    auto targets = torch::empty(num_samples, torch::kByte);
+    auto images = torch::empty({num_samples, 3, kImageRows, kImageColumns}, torch::kByte);
+
+    for (uint32_t i = 0; i != num_samples; ++i) {
+        // The first byte of each row is the target class index.
+        uint32_t start_index = i * kBytesPerRow;
+        targets[i] = data_buffer[start_index];
+
+        // The next bytes correspond to the rgb channel values in the following order:
+        // red (32 * 32 = 1024 bytes) | green (1024 bytes) | blue (1024 bytes)
+        uint32_t image_start = start_index + 1;
+        uint32_t image_end = image_start + 3 * kBytesPerChannelPerRow;
+        std::copy(&data_buffer[image_start], &data_buffer[image_end],
+                  reinterpret_cast<char*>(images[i].data_ptr()));
+    }
+
+    return {images.to(torch::kFloat32).div_(255), targets.to(torch::kInt64)};
+}
+} // namespace
+
+CIFAR10::CIFAR10(const std::string& root, Mode mode) : mode_(mode) {
+    auto data = read_data(root, mode == Mode::kTrain);
+
+    images_ = std::move(data.first);
+    targets_ = std::move(data.second);
+}
+
+torch::data::Example<> CIFAR10::get(size_t index) {
+    return {images_[index], targets_[index]};
+}
+
+torch::optional<size_t> CIFAR10::size() const {
+    return images_.size(0);
+}
+
+bool CIFAR10::is_train() const noexcept {
+    return mode_ == Mode::kTrain;
+}
+
+const torch::Tensor& CIFAR10::images() const {
+    return images_;
+}
+
+const torch::Tensor& CIFAR10::targets() const {
+    return targets_;
+}
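
To show how the pieces fit together, here is a hypothetical sketch of how a main.cpp could wire the CIFAR10 dataset and the transforms into a data loader. The data path, batch size, and augmentation order are illustrative assumptions and may differ from the commit's actual src/main.cpp.

// Hypothetical wiring sketch; not the committed src/main.cpp.
#include <torch/torch.h>
#include <string>
#include "cifar10.h"
#include "transform.h"

int main() {
    const std::string cifar10_path = "./data/cifar10";  // assumed location of the binary files
    const int64_t batch_size = 100;                     // assumed batch size

    // Standard CIFAR10 augmentation: pad to 40x40, randomly crop back to 32x32, randomly flip.
    auto train_dataset = CIFAR10(cifar10_path, CIFAR10::Mode::kTrain)
        .map(transform::ConstantPad(4))
        .map(transform::RandomCrop({32, 32}))
        .map(transform::RandomHorizontalFlip())
        .map(torch::data::transforms::Stack<>());

    auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
        std::move(train_dataset), batch_size);

    for (auto& batch : *train_loader) {
        auto images = batch.data;     // shape [batch_size, 3, 32, 32]
        auto targets = batch.target;  // shape [batch_size]
        // ... forward pass through the ResNet, loss, optimizer step ...
        break;  // one batch is enough for this sketch
    }
}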
