Skip to content

Commit e5e700e

Browse files
committed
blitz: Added cifar, training a classifier tutorial
1 parent 196b8dd commit e5e700e

File tree

9 files changed

+318
-12
lines changed

9 files changed

+318
-12
lines changed

tutorials/popular/blitz/autograd/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ int main() {
4545
// Example of vector-Jacobian product:
4646
x = torch::randn(3, torch::TensorOptions().requires_grad(true));
4747
y = x * 2;
48-
while (y.data().norm().item().toInt() < 1000) {
48+
while (y.data().norm().item<int>() < 1000) {
4949
y = y * 2;
5050
}
5151
std::cout << "y:\n" << y << '\n';

tutorials/popular/blitz/tensors/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,5 @@ int main() {
6363
// If you have a one-element tensor, use .item<T>() to get its value as a plain C++ scalar
6464
x = torch::randn(1);
6565
std::cout << "x:\n" << x << '\n';
66-
std::cout << "x.item():\n" << x.item().toFloat() << '\n';
66+
std::cout << "x.item():\n" << x.item<float>() << '\n';
6767
}

tutorials/popular/blitz/training_a_classifier/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@ endif()
1010
set(EXECUTABLE_NAME training-a-classifier)
1111

1212
add_executable(${EXECUTABLE_NAME})
13-
target_sources(${EXECUTABLE_NAME} PRIVATE main.cpp)
13+
target_sources(${EXECUTABLE_NAME} PRIVATE src/main.cpp
14+
src/nnet.cpp
15+
src/cifar10.cpp
16+
include/nnet.h
17+
include/cifar10.h
18+
)
19+
20+
target_include_directories(${EXECUTABLE_NAME} PRIVATE include)
1421

1522
target_link_libraries(${EXECUTABLE_NAME} ${TORCH_LIBRARIES})
1623

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright 2020-present pytorch-cpp Authors
2+
#pragma once
3+
4+
#include <torch/data/datasets/base.h>
5+
#include <torch/data/example.h>
6+
#include <torch/types.h>
7+
#include <cstddef>
8+
#include <fstream>
9+
#include <string>
10+
11+
// CIFAR10 dataset
12+
// based on: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/data/datasets/mnist.h.
13+
class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
14+
public:
15+
// The mode in which the dataset is loaded
16+
enum Mode { kTrain, kTest };
17+
18+
// Loads the CIFAR10 dataset from the `root` path.
19+
//
20+
// The supplied `root` path should contain the *content* of the unzipped
21+
// CIFAR10 dataset (binary version), available from http://www.cs.toronto.edu/~kriz/cifar.html.
22+
explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain);
23+
24+
// Returns the `Example` at the given `index`.
25+
torch::data::Example<> get(size_t index) override;
26+
27+
// Returns the size of the dataset.
28+
torch::optional<size_t> size() const override;
29+
30+
// Returns true if this is the training subset of CIFAR10.
31+
bool is_train() const noexcept;
32+
33+
// Returns all images stacked into a single tensor.
34+
const torch::Tensor& images() const;
35+
36+
// Returns all targets stacked into a single tensor.
37+
const torch::Tensor& targets() const;
38+
39+
private:
40+
torch::Tensor images_;
41+
torch::Tensor targets_;
42+
Mode mode_;
43+
};
44+
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright 2020-present pytorch-cpp Authors
2+
#pragma once
3+
4+
#include <torch/torch.h>
5+
6+
class NetImpl : public torch::nn::Module {
7+
public:
8+
NetImpl();
9+
torch::Tensor forward(torch::Tensor x);
10+
11+
private:
12+
torch::nn::Conv2d conv1;
13+
torch::nn::MaxPool2d pool;
14+
torch::nn::Conv2d conv2;
15+
torch::nn::Linear fc1;
16+
torch::nn::Linear fc2;
17+
torch::nn::Linear fc3;
18+
};
19+
20+
TORCH_MODULE(Net);

tutorials/popular/blitz/training_a_classifier/main.cpp

Lines changed: 0 additions & 9 deletions
This file was deleted.
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Copyright 2020-present pytorch-cpp Authors
2+
#include "cifar10.h"
3+
4+
namespace {
5+
// Layout of the CIFAR10 binary files; the dataset is described at
// https://www.cs.toronto.edu/~kriz/cifar.html.

// Number of samples stored in a single batch file.
constexpr uint32_t kSizePerBatch = 10000;
// Sample counts of the training subset (five batch files) and the test subset (one batch file).
constexpr uint32_t kTrainSize = 5 * kSizePerBatch;
constexpr uint32_t kTestSize = kSizePerBatch;

// Spatial dimensions of one image.
constexpr uint32_t kImageRows = 32;
constexpr uint32_t kImageColumns = 32;

// Bytes taken by one colour channel of one record (32 * 32 = 1024).
constexpr uint32_t kBytesPerChannelPerRow = kImageRows * kImageColumns;
// Bytes of one complete record: one label byte plus three colour channels (3073).
constexpr uint32_t kBytesPerRow = 1 + 3 * kBytesPerChannelPerRow;
// Expected size in bytes of every batch file.
constexpr uint32_t kBytesPerBatchFile = kBytesPerRow * kSizePerBatch;

// File names that make up the training subset.
const std::vector<std::string> kTrainDataBatchFiles = {
    "data_batch_1.bin",
    "data_batch_2.bin",
    "data_batch_3.bin",
    "data_batch_4.bin",
    "data_batch_5.bin",
};

// File names that make up the test subset.
const std::vector<std::string> kTestDataBatchFiles = {
    "test_batch.bin"
};
26+
27+
// Joins the directory `head` and the file name `tail` with exactly one '/'
// between them.
//
// An empty `head` yields `tail` unchanged (path relative to the current
// directory); calling `back()` on an empty string would be undefined behaviour.
//
// Source: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp.
std::string join_paths(std::string head, const std::string& tail) {
    if (head.empty()) {
        return tail;
    }
    if (head.back() != '/') {
        head.push_back('/');
    }
    head += tail;
    return head;
}
35+
// Partially based on https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp.
36+
std::pair<torch::Tensor, torch::Tensor> read_data(const std::string& root, bool train) {
37+
const auto& files = train ? kTrainDataBatchFiles : kTestDataBatchFiles;
38+
const auto num_samples = train ? kTrainSize : kTestSize;
39+
40+
std::vector<char> data_buffer;
41+
data_buffer.reserve(files.size() * kBytesPerBatchFile);
42+
43+
for (const auto& file : files) {
44+
const auto path = join_paths(root, file);
45+
std::ifstream data(path, std::ios::binary);
46+
TORCH_CHECK(data, "Error opening data file at", path);
47+
48+
data_buffer.insert(data_buffer.end(), std::istreambuf_iterator<char>(data), {});
49+
}
50+
51+
TORCH_CHECK(data_buffer.size() == files.size() * kBytesPerBatchFile, "Unexpected file sizes");
52+
53+
auto targets = torch::empty(num_samples, torch::kByte);
54+
auto images = torch::empty({num_samples, 3, kImageRows, kImageColumns}, torch::kByte);
55+
56+
for (uint32_t i = 0; i != num_samples; ++i) {
57+
// The first byte of each row is the target class index.
58+
uint32_t start_index = i * kBytesPerRow;
59+
targets[i] = data_buffer[start_index];
60+
61+
// The next bytes correspond to the rgb channel values in the following order:
62+
// red (32 *32 = 1024 bytes) | green (1024 bytes) | blue (1024 bytes)
63+
uint32_t image_start = start_index + 1;
64+
uint32_t image_end = image_start + 3 * kBytesPerChannelPerRow;
65+
std::copy(data_buffer.begin() + image_start, data_buffer.begin() + image_end,
66+
reinterpret_cast<char*>(images[i].data_ptr()));
67+
}
68+
69+
return {images.to(torch::kFloat32).div_(255), targets.to(torch::kInt64)};
70+
}
71+
} // namespace
72+
73+
CIFAR10::CIFAR10(const std::string& root, Mode mode) : mode_(mode) {
74+
auto data = read_data(root, mode == Mode::kTrain);
75+
76+
images_ = std::move(data.first);
77+
targets_ = std::move(data.second);
78+
}
79+
80+
torch::data::Example<> CIFAR10::get(size_t index) {
81+
return {images_[index], targets_[index]};
82+
}
83+
84+
torch::optional<size_t> CIFAR10::size() const {
85+
return images_.size(0);
86+
}
87+
88+
bool CIFAR10::is_train() const noexcept {
89+
return mode_ == Mode::kTrain;
90+
}
91+
92+
const torch::Tensor& CIFAR10::images() const {
93+
return images_;
94+
}
95+
96+
const torch::Tensor& CIFAR10::targets() const {
97+
return targets_;
98+
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Copyright 2020-present pytorch-cpp Authors
2+
#include <torch/torch.h>
3+
#include <iostream>
4+
#include <vector>
5+
#include <iomanip>
6+
#include "nnet.h"
7+
#include "cifar10.h"
8+
9+
int main() {
10+
std::cout << "Deep Learning with PyTorch: A 60 Minute Blitz\n\n";
11+
std::cout << "Training a Classifier\n\n";
12+
13+
// Loading and normalizing CIFAR10
14+
const std::string CIFAR_data_path = "../../../../../data/cifar10/";
15+
16+
auto train_dataset = CIFAR10(CIFAR_data_path)
17+
.map(torch::data::transforms::Normalize<>({0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}))
18+
.map(torch::data::transforms::Stack<>());
19+
auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
20+
std::move(train_dataset), 4);
21+
22+
auto test_dataset = CIFAR10(CIFAR_data_path, CIFAR10::Mode::kTest)
23+
.map(torch::data::transforms::Normalize<>({0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}))
24+
.map(torch::data::transforms::Stack<>());
25+
auto test_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
26+
std::move(test_dataset), 4);
27+
28+
std::string classes[10] = {"plane", "car", "bird", "cat",
29+
"deer", "dog", "frog", "horse", "ship", "truck"};
30+
31+
// Define a Convolutional Neural Network
32+
Net net = Net();
33+
net->to(torch::kCPU);
34+
35+
// // Define a Loss function and optimizer
36+
torch::nn::CrossEntropyLoss criterion;
37+
torch::optim::SGD optimizer(net->parameters(), torch::optim::SGDOptions(0.001).momentum(0.9));
38+
39+
// Train the network
40+
for (size_t epoch = 0; epoch < 2; ++epoch) {
41+
double running_loss = 0.0;
42+
43+
int i = 0;
44+
for (auto& batch : *train_loader) {
45+
// get the inputs; data is a list of [inputs, labels]
46+
auto inputs = batch.data.to(torch::kCPU);
47+
auto labels = batch.target.to(torch::kCPU);
48+
49+
// zero the parameter gradients
50+
optimizer.zero_grad();
51+
52+
// forward + backward + optimize
53+
auto outputs = net->forward(inputs);
54+
auto loss = criterion(outputs, labels);
55+
loss.backward();
56+
optimizer.step();
57+
58+
// print statistics
59+
running_loss += loss.item<double>();
60+
if (i % 2000 == 1999) { // print every 2000 mini-batches
61+
std::cout << "[" << epoch + 1 << ", " << i + 1 << "] loss: "
62+
<< running_loss / 2000 << '\n';
63+
running_loss = 0.0;
64+
}
65+
i++;
66+
}
67+
}
68+
std::cout << "Finished Training\n\n";
69+
70+
std::string PATH = "./cifar_net.pth";
71+
// torch::save(net, PATH);
72+
73+
// Test the network on the test data
74+
net = Net();
75+
torch::load(net, PATH);
76+
77+
int correct = 0;
78+
int total = 0;
79+
for (const auto& batch : *test_loader) {
80+
auto images = batch.data.to(torch::kCPU);
81+
auto labels = batch.target.to(torch::kCPU);
82+
83+
auto outputs = net->forward(images);
84+
85+
auto out_tuple = torch::max(outputs, 1);
86+
auto predicted = std::get<1>(out_tuple);
87+
total += labels.size(0);
88+
correct += (predicted == labels).sum().item<int>();
89+
}
90+
91+
std::cout << "Accuracy of the network on the 10000 test images: "
92+
<< (100 * correct / total) << "%\n\n";
93+
94+
float class_correct[10];
95+
float class_total[10];
96+
97+
torch::NoGradGuard no_grad;
98+
99+
for (const auto& batch : *test_loader) {
100+
auto images = batch.data.to(torch::kCPU);
101+
auto labels = batch.target.to(torch::kCPU);
102+
103+
auto outputs = net->forward(images);
104+
105+
auto out_tuple = torch::max(outputs, 1);
106+
auto predicted = std::get<1>(out_tuple);
107+
auto c = (predicted == labels).squeeze();
108+
109+
for (int i = 0; i < 4; ++i) {
110+
auto label = labels[i].item<int>();
111+
class_correct[label] += c[i].item<float>();
112+
class_total[label] += 1;
113+
}
114+
}
115+
116+
for (int i = 0; i < 10; ++i) {
117+
std::cout << "Accuracy of " << classes[i] << " "
118+
<< 100 * class_correct[i] / class_total[i] << "%\n";
119+
}
120+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright 2020-present pytorch-cpp Authors
2+
#include "nnet.h"
3+
#include <torch/torch.h>
4+
5+
NetImpl::NetImpl() :
6+
conv1(torch::nn::Conv2dOptions(3, 6, 5)),
7+
pool(torch::nn::MaxPool2dOptions({2, 2})),
8+
conv2(torch::nn::Conv2dOptions(6, 16, 5)),
9+
fc1(torch::nn::LinearOptions(16 * 5 * 5, 120)),
10+
fc2(torch::nn::LinearOptions(120, 84)),
11+
fc3(torch::nn::LinearOptions(84, 10)) {
12+
register_module("conv1", conv1);
13+
register_module("conv2", conv2);
14+
register_module("fc1", fc1);
15+
register_module("fc2", fc2);
16+
register_module("fc3", fc3);
17+
}
18+
19+
torch::Tensor NetImpl::forward(torch::Tensor x) {
20+
auto out = pool->forward(torch::relu(conv1->forward(x)));
21+
out = pool->forward(torch::relu(conv2->forward(out)));
22+
out = out.view({-1, 16 * 5 * 5});
23+
out = torch::relu(fc1->forward(out));
24+
out = torch::relu(fc2->forward(out));
25+
return fc3->forward(out);
26+
}

0 commit comments

Comments
 (0)