5
5
#include < iostream>
6
6
#include < string>
7
7
#include < vector>
8
+ #include < random>
8
9
9
10
struct Options {
10
11
int image_size = 224 ;
@@ -55,7 +56,7 @@ class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {
55
56
auto tdata = torch::cat ({R, G, B})
56
57
.view ({3 , options.image_size , options.image_size })
57
58
.to (torch::kFloat );
58
- auto tlabel = torch::from_blob (& data[index].second , { 1 } , torch::kLong );
59
+ auto tlabel = torch::tensor ( data[index].second , torch::kLong );
59
60
return {tdata, tlabel};
60
61
}
61
62
@@ -65,6 +66,8 @@ class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {
65
66
};
66
67
67
68
std::pair<Data, Data> readInfo () {
69
+ std::random_device randomDevice;
70
+ std::mt19937 mersenneTwisterGenerator (randomDevice ());
68
71
Data train, test;
69
72
70
73
std::ifstream stream (options.infoFilePath );
@@ -87,8 +90,8 @@ std::pair<Data, Data> readInfo() {
87
90
break ;
88
91
}
89
92
90
- std::random_shuffle (train.begin (), train.end ());
91
- std::random_shuffle (test.begin (), test.end ());
93
+ std::shuffle (train.begin (), train.end (), mersenneTwisterGenerator );
94
+ std::shuffle (test.begin (), test.end (), mersenneTwisterGenerator );
92
95
return std::make_pair (train, test);
93
96
}
94
97
@@ -119,7 +122,8 @@ struct NetworkImpl : torch::nn::SequentialImpl {
119
122
push_back (Linear (4096 , 4096 ));
120
123
push_back (Functional (torch::relu));
121
124
push_back (Linear (4096 , 102 ));
122
- push_back (Functional (torch::log_softmax, 1 , torch::nullopt));
125
+ push_back (Functional (
126
+ [](torch::Tensor input) { return torch::log_softmax (input, 1 ); }));
123
127
}
124
128
};
125
129
TORCH_MODULE (Network);
0 commit comments