Skip to content

Commit 2072237

Browse files
committed
Added feedforward neural network
1 parent 91e79d6 commit 2072237

File tree

1 file changed

+93
-0
lines changed
  • tutorials/basics/feedforward_neural_network

1 file changed

+93
-0
lines changed

tutorials/basics/feedforward_neural_network/main.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,99 @@
22
#include <torch/torch.h>
33
#include <iostream>
44

5+
// Hyper parameters
6+
const int input_size = 784;
7+
const int hidden_size = 500;
8+
const int num_classes = 10;
9+
const int num_epochs = 5;
10+
const int batch_size = 100;
11+
const double learning_rate = 0.001;
12+
13+
struct NeuralNet: torch::nn::Module {
14+
// Declare all the layers of nerual network
15+
torch::nn::Linear fc1{nullptr}, fc2{nullptr};
16+
17+
// Construct all the layers
18+
NeuralNet() {
19+
fc1 = register_module("fc1", torch::nn::Linear(input_size, hidden_size));
20+
fc2 = register_module("fc2", torch::nn::Linear(hidden_size, num_classes));
21+
}
22+
23+
torch::Tensor forward(torch::Tensor x) {
24+
x = torch::relu(fc1->forward(x));
25+
x = fc2->forward(x);
26+
return torch::log_softmax(x, 1);
27+
}
28+
};
29+
530
int main() {
631
std::cout << "FeedForward Neural Network" << std::endl;
32+
33+
// Device
34+
torch::DeviceType device_type;
35+
if (torch::cuda::is_available()) {
36+
std::cout << "CUDA available. Training on GPU." << std::endl;
37+
device_type = torch::kCUDA;
38+
} else {
39+
std::cout << "Training on CPU." << std::endl;
40+
device_type = torch::kCPU;
41+
}
42+
torch::Device device(device_type);
43+
44+
// MNIST Dataset (images and labels)
45+
auto train_dataset = torch::data::datasets::MNIST("../data")
46+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
47+
.map(torch::data::transforms::Stack<>());
48+
auto test_dataset = torch::data::datasets::MNIST("../data", torch::data::datasets::MNIST::Mode::kTest)
49+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
50+
.map(torch::data::transforms::Stack<>());
51+
52+
// Data loader (input pipeline)
53+
auto train_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
54+
std::move(train_dataset), batch_size);
55+
auto test_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
56+
std::move(test_dataset), batch_size);
57+
58+
// Neural Network model
59+
auto model = std::make_shared<NeuralNet>();
60+
61+
// Loss and optimizer
62+
auto optimizer = torch::optim::SGD(model->parameters(), torch::optim::SGDOptions(learning_rate));
63+
64+
// Train the model
65+
for (int epoch = 0; epoch < num_epochs; epoch++) {
66+
int i = 0;
67+
for (auto& batch : *train_loader) {
68+
auto data = batch.data.to(device), labels = batch.target.to(device);
69+
70+
// Forward pass
71+
auto outputs = model->forward(data);
72+
auto loss = torch::nll_loss(outputs, labels);
73+
74+
// Backward and optimize
75+
optimizer.zero_grad();
76+
loss.backward();
77+
optimizer.step();
78+
79+
if ((i+1) % 5 == 0) {
80+
std::cout << "Epoch [" << (epoch+1) << "/" << num_epochs << "], Batch: "
81+
<< (i+1) << ", Loss: " << loss.item().toFloat() << std::endl;
82+
}
83+
}
84+
}
85+
86+
// Test the model
87+
torch::NoGradGuard no_grad;
88+
int correct = 0;
89+
int total = 0;
90+
for (const auto& batch : *test_loader) {
91+
auto data = batch.data.to(device), labels = batch.target.to(device);
92+
auto outputs = model->forward(data);
93+
auto predicted = outputs.argmax(1);
94+
total += labels.size(0);
95+
correct += predicted.eq(labels).sum().template item<int>();
96+
}
97+
98+
std::cout << "Accuracy of the model on the 10000 test images: " <<
99+
static_cast<double>(100 * correct / total) << std::endl;
7100
}

0 commit comments

Comments
 (0)