Skip to content

Commit 1d9bafb

Browse files
committed
Added Logistic Regression. This closes #2
1 parent f4a0ffa commit 1d9bafb

File tree

3 files changed

+77
-2
lines changed

3 files changed

+77
-2
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
--------------------------------------------------------------------------------
44

5-
This repository provides tutorial code in C++ for deep learning researchers to learn PyTorch.
6-
**Python Tutorial**: [https://github.com/yunjey/pytorch-tutorial](https://github.com/yunjey/pytorch-tutorial)
5+
This repository provides tutorial code in C++ for deep learning researchers to learn PyTorch.
76

87
## Getting Started
98
- Fork/Clone and Install

tutorials/basics/linear_regression/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ int main() {
2424
auto criterion = torch::nn::L1Loss();
2525
auto optimizer = torch::optim::SGD(model->parameters(), torch::optim::SGDOptions(learning_rate));
2626

27+
// Train the model
2728
for (int epoch = 0; epoch < num_epochs; epoch++) {
2829
// Array to tensors
2930
auto inputs = x_train;

tutorials/basics/logistic_regression/main.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,79 @@
44

55
int main() {
66
std::cout << "Logistic Regression" << std::endl;
7+
8+
// Device
9+
torch::DeviceType device_type;
10+
if (torch::cuda::is_available()) {
11+
std::cout << "CUDA available. Training on GPU." << std::endl;
12+
device_type = torch::kCUDA;
13+
} else {
14+
std::cout << "Training on CPU." << std::endl;
15+
device_type = torch::kCPU;
16+
}
17+
torch::Device device(device_type);
18+
19+
// Hyper parameters
20+
int input_size = 784;
21+
int num_classes = 10;
22+
int num_epochs = 5;
23+
int batch_size = 100;
24+
double learning_rate = 0.001;
25+
26+
// MNIST Dataset (images and labels)
27+
auto train_dataset = torch::data::datasets::MNIST("../data")
28+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
29+
.map(torch::data::transforms::Stack<>());
30+
auto test_dataset = torch::data::datasets::MNIST("../data", torch::data::datasets::MNIST::Mode::kTest)
31+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
32+
.map(torch::data::transforms::Stack<>());
33+
34+
// Data loader (input pipeline)
35+
auto train_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
36+
std::move(train_dataset), batch_size);
37+
auto test_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
38+
std::move(test_dataset), batch_size);
39+
40+
// Logistic regression model
41+
auto model = torch::nn::Linear(input_size, num_classes);
42+
43+
// Loss and optimizer
44+
auto optimizer = torch::optim::SGD(model->parameters(), torch::optim::SGDOptions(learning_rate));
45+
46+
// Train the model
47+
for (int epoch = 0; epoch < num_epochs; epoch++) {
48+
int i = 0;
49+
for (auto& batch : *train_loader) {
50+
auto data = batch.data.to(device), labels = batch.target.to(device);
51+
52+
// Forward pass
53+
auto outputs = model->forward(data);
54+
auto loss = torch::nll_loss(outputs, labels);
55+
56+
// Backward and optimize
57+
optimizer.zero_grad();
58+
loss.backward();
59+
optimizer.step();
60+
61+
if ((i+1) % 5 == 0) {
62+
std::cout << "Epoch [" << (epoch+1) << "/" << num_epochs << "], Batch: "
63+
<< (i+1) << ", Loss: " << loss.item().toFloat() << std::endl;
64+
}
65+
}
66+
}
67+
68+
// Test the model
69+
torch::NoGradGuard no_grad;
70+
int correct = 0;
71+
int total = 0;
72+
for (const auto& batch : *test_loader) {
73+
auto data = batch.data.to(device), labels = batch.target.to(device);
74+
auto outputs = model->forward(data);
75+
auto predicted = outputs.argmax(1);
76+
total += labels.size(0);
77+
correct += predicted.eq(labels).sum().template item<int>();
78+
}
79+
80+
std::cout << "Accuracy of the model on the 10000 test images: " <<
81+
static_cast<double>(100 * correct / total) << std::endl;
782
}

0 commit comments

Comments (0)