Skip to content

Commit ca8adf9

Browse files
committed
Change Cifar10 to CIFAR10
1 parent 60a1d52 commit ca8adf9

File tree

3 files changed

+16
-16
lines changed

3 files changed

+16
-16
lines changed

tutorials/intermediate/deep_residual_network/include/cifar10.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,26 @@
77
#include <cstddef>
88
#include <string>
99

10-
// Cifar10 dataset
10+
// CIFAR10 dataset
1111
// based on: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/data/datasets/mnist.h.
12-
class Cifar10 : public torch::data::datasets::Dataset<Cifar10> {
12+
class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
1313
public:
1414
// The mode in which the dataset is loaded
1515
enum Mode { kTrain, kTest };
1616

17-
// Loads the Cifar10 dataset from the `root` path.
17+
// Loads the CIFAR10 dataset from the `root` path.
1818
//
1919
// The supplied `root` path should contain the *content* of the unzipped
20-
// Cifar10 dataset (binary version), available from http://www.cs.toronto.edu/~kriz/cifar.html.
21-
explicit Cifar10(const std::string& root, Mode mode = Mode::kTrain);
20+
// CIFAR10 dataset (binary version), available from http://www.cs.toronto.edu/~kriz/cifar.html.
21+
explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain);
2222

2323
// Returns the `Example` at the given `index`.
2424
torch::data::Example<> get(size_t index) override;
2525

2626
// Returns the size of the dataset.
2727
torch::optional<size_t> size() const override;
2828

29-
// Returns true if this is the training subset of Cifar10.
29+
// Returns true if this is the training subset of CIFAR10.
3030
bool is_train() const noexcept;
3131

3232
// Returns all images stacked into a single tensor.

tutorials/intermediate/deep_residual_network/src/cifar10.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#include "cifar10.h"
33

44
namespace {
5-
// Cifar10 dataset description can be found at https://www.cs.toronto.edu/~kriz/cifar.html.
5+
// CIFAR10 dataset description can be found at https://www.cs.toronto.edu/~kriz/cifar.html.
66
constexpr uint32_t kTrainSize = 50000;
77
constexpr uint32_t kTestSize = 10000;
88
constexpr uint32_t kSizePerBatch = 10000;
@@ -70,29 +70,29 @@ std::pair<torch::Tensor, torch::Tensor> read_data(const std::string& root, bool
7070
}
7171
} // namespace
7272

73-
Cifar10::Cifar10(const std::string& root, Mode mode) : mode_(mode) {
73+
CIFAR10::CIFAR10(const std::string& root, Mode mode) : mode_(mode) {
7474
auto data = read_data(root, mode == Mode::kTrain);
7575

7676
images_ = std::move(data.first);
7777
targets_ = std::move(data.second);
7878
}
7979

80-
torch::data::Example<> Cifar10::get(size_t index) {
80+
torch::data::Example<> CIFAR10::get(size_t index) {
8181
return {images_[index], targets_[index]};
8282
}
8383

84-
torch::optional<size_t> Cifar10::size() const {
84+
torch::optional<size_t> CIFAR10::size() const {
8585
return images_.size(0);
8686
}
8787

88-
bool Cifar10::is_train() const noexcept {
88+
bool CIFAR10::is_train() const noexcept {
8989
return mode_ == Mode::kTrain;
9090
}
9191

92-
const torch::Tensor& Cifar10::images() const {
92+
const torch::Tensor& CIFAR10::images() const {
9393
return images_;
9494
}
9595

96-
const torch::Tensor& Cifar10::targets() const {
96+
const torch::Tensor& CIFAR10::targets() const {
9797
return targets_;
9898
}

tutorials/intermediate/deep_residual_network/src/main.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ int main() {
2828

2929
const std::string CIFAR_data_path = "../../../../tutorials/intermediate/deep_residual_network/data/";
3030

31-
// Cifar10 custom dataset
32-
auto train_dataset = Cifar10(CIFAR_data_path)
31+
// CIFAR10 custom dataset
32+
auto train_dataset = CIFAR10(CIFAR_data_path)
3333
.map(ConstantPad(4))
3434
.map(RandomHorizontalFlip())
3535
.map(RandomCrop({32, 32}))
@@ -38,7 +38,7 @@ int main() {
3838
// Number of samples in the training set
3939
auto num_train_samples = train_dataset.size().value();
4040

41-
auto test_dataset = Cifar10(CIFAR_data_path, Cifar10::Mode::kTest)
41+
auto test_dataset = CIFAR10(CIFAR_data_path, CIFAR10::Mode::kTest)
4242
.map(torch::data::transforms::Stack<>());
4343

4444
// Number of samples in the testset

0 commit comments

Comments
 (0)