Skip to content

Commit 97adea1

Browse files
authored
Update the readme and fix bugs in custom-dataset example (#1214)
amend amend
1 parent 5921fc1 commit 97adea1

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

cpp/custom-dataset/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
This folder contains an example of loading a custom image dataset with OpenCV and training a model to label images, using the PyTorch C++ frontend.
44

5-
The dataset used here is [Caltech 101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/) dataset.
5+
The dataset used here is [Caltech 101](https://data.caltech.edu/records/mzrjq-6wc02) dataset.
66

77
The entire training code is contained in custom-data.cpp.
88

cpp/custom-dataset/custom-dataset.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <iostream>
66
#include <string>
77
#include <vector>
8+
#include <random>
89

910
struct Options {
1011
int image_size = 224;
@@ -55,7 +56,7 @@ class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {
5556
auto tdata = torch::cat({R, G, B})
5657
.view({3, options.image_size, options.image_size})
5758
.to(torch::kFloat);
58-
auto tlabel = torch::from_blob(&data[index].second, {1}, torch::kLong);
59+
auto tlabel = torch::tensor(data[index].second, torch::kLong);
5960
return {tdata, tlabel};
6061
}
6162

@@ -65,6 +66,8 @@ class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {
6566
};
6667

6768
std::pair<Data, Data> readInfo() {
69+
std::random_device randomDevice;
70+
std::mt19937 mersenneTwisterGenerator(randomDevice());
6871
Data train, test;
6972

7073
std::ifstream stream(options.infoFilePath);
@@ -87,8 +90,8 @@ std::pair<Data, Data> readInfo() {
8790
break;
8891
}
8992

90-
std::random_shuffle(train.begin(), train.end());
91-
std::random_shuffle(test.begin(), test.end());
93+
std::shuffle(train.begin(), train.end(), mersenneTwisterGenerator);
94+
std::shuffle(test.begin(), test.end(), mersenneTwisterGenerator);
9295
return std::make_pair(train, test);
9396
}
9497

@@ -119,7 +122,8 @@ struct NetworkImpl : torch::nn::SequentialImpl {
119122
push_back(Linear(4096, 4096));
120123
push_back(Functional(torch::relu));
121124
push_back(Linear(4096, 102));
122-
push_back(Functional(torch::log_softmax, 1, torch::nullopt));
125+
push_back(Functional(
126+
[](torch::Tensor input) { return torch::log_softmax(input, 1); }));
123127
}
124128
};
125129
TORCH_MODULE(Network);

0 commit comments

Comments
 (0)