Skip to content

Commit 7a8c5be

Browse files
authored
Add files via upload
1 parent e158589 commit 7a8c5be

File tree

2 files changed

+26
-23
lines changed

2 files changed

+26
-23
lines changed

Multiclass_Classification/EfficientNet/src/networks.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ void StochasticDepthImpl::pretty_print(std::ostream& stream) const{
5454
// ----------------------------------------------------------------------
5555
// struct{Conv2dNormActivationImpl}(nn::Module) -> constructor
5656
// ----------------------------------------------------------------------
57-
Conv2dNormActivationImpl::Conv2dNormActivationImpl(const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t padding, const size_t groups, const bool SiLU){
57+
Conv2dNormActivationImpl::Conv2dNormActivationImpl(const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t padding, const size_t groups, const float eps, const float momentum, const bool SiLU){
5858
this->model = nn::Sequential(
5959
nn::Conv2d(nn::Conv2dOptions(/*in_channels=*/in_nc, /*out_channels=*/out_nc, /*kernel_size=*/kernel_size).stride(stride).padding(padding).groups(groups).bias(false)),
60-
nn::BatchNorm2d(out_nc)
60+
nn::BatchNorm2d(nn::BatchNormOptions(out_nc).eps(eps).momentum(momentum))
6161
);
6262
if (SiLU) this->model->push_back(nn::SiLU());
6363
register_module("Conv2dNormActivation", this->model);
@@ -113,16 +113,16 @@ torch::Tensor SqueezeExcitationImpl::forward(torch::Tensor x){
113113
// ----------------------------------------------------------------------
114114
// struct{MBConvImpl}(nn::Module) -> constructor
115115
// ----------------------------------------------------------------------
116-
MBConvImpl::MBConvImpl(const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t exp, const float dropconnect){
116+
MBConvImpl::MBConvImpl(const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t exp, const float eps, const float momentum, const float dropconnect){
117117

118118
constexpr size_t reduce = 4;
119119
size_t mid = in_nc * exp;
120120
this->residual = ((stride == 1) && (in_nc == out_nc));
121121

122-
if (exp != 1) this->block->push_back(Conv2dNormActivation(in_nc, mid, /*kernel_size=*/1, /*stride=*/1, /*padding=*/0, /*groups=*/1, /*SiLU=*/true));
123-
this->block->push_back(Conv2dNormActivation(mid, mid, /*kernel_size=*/kernel_size, /*stride=*/stride, /*padding=*/kernel_size / 2, /*groups=*/mid, /*SiLU=*/true));
122+
if (exp != 1) this->block->push_back(Conv2dNormActivation(in_nc, mid, /*kernel_size=*/1, /*stride=*/1, /*padding=*/0, /*groups=*/1, /*eps=*/eps, /*momentum=*/momentum, /*SiLU=*/true));
123+
this->block->push_back(Conv2dNormActivation(mid, mid, /*kernel_size=*/kernel_size, /*stride=*/stride, /*padding=*/kernel_size / 2, /*groups=*/mid, /*eps=*/eps, /*momentum=*/momentum, /*SiLU=*/true));
124124
this->block->push_back(SqueezeExcitation(mid, std::max(1, int(in_nc / reduce))));
125-
this->block->push_back(Conv2dNormActivation(mid, out_nc, /*kernel_size=*/1, /*stride=*/1, /*padding=*/0, /*groups=*/1, /*SiLU=*/false));
125+
this->block->push_back(Conv2dNormActivation(mid, out_nc, /*kernel_size=*/1, /*stride=*/1, /*padding=*/0, /*groups=*/1, /*eps=*/eps, /*momentum=*/momentum, /*SiLU=*/false));
126126
register_module("block", this->block);
127127

128128
this->sd = StochasticDepth(dropconnect);
@@ -170,7 +170,7 @@ size_t MC_EfficientNetImpl::round_filters(size_t c, double width_mul){
170170
// struct{MC_EfficientNetImpl}(nn::Module) -> function{round_repeats}
171171
// ----------------------------------------------------------------------
172172
size_t MC_EfficientNetImpl::round_repeats(size_t r, double depth_mul){
173-
return std::max(1, int(std::round(r * depth_mul)));
173+
return std::max(1, int(std::ceil(r * depth_mul)));
174174
}
175175

176176

@@ -186,16 +186,16 @@ MC_EfficientNetImpl::MC_EfficientNetImpl(po::variables_map &vm){
186186

187187
// (0.a) Setting for network's config
188188
std::string network = vm["network"].as<std::string>();
189-
if (network == "B0") this->cfg = {1.0, 1.0, 224, 0.2};
190-
else if (network == "B1") this->cfg = {1.0, 1.1, 240, 0.2};
191-
else if (network == "B2") this->cfg = {1.1, 1.2, 260, 0.3};
192-
else if (network == "B3") this->cfg = {1.2, 1.4, 300, 0.3};
193-
else if (network == "B4") this->cfg = {1.4, 1.8, 380, 0.4};
194-
else if (network == "B5") this->cfg = {1.6, 2.2, 456, 0.4};
195-
else if (network == "B6") this->cfg = {1.8, 2.6, 528, 0.5};
196-
else if (network == "B7") this->cfg = {2.0, 3.1, 600, 0.5};
197-
else if (network == "B8") this->cfg = {2.2, 3.6, 672, 0.5};
198-
else if (network == "L2") this->cfg = {4.3, 5.3, 800, 0.5};
189+
if (network == "B0") this->cfg = {1.0, 1.0, 224, 0.2, 1e-5, 0.1, 0.2};
190+
else if (network == "B1") this->cfg = {1.0, 1.1, 240, 0.2, 1e-5, 0.1, 0.2};
191+
else if (network == "B2") this->cfg = {1.1, 1.2, 260, 0.3, 1e-5, 0.1, 0.2};
192+
else if (network == "B3") this->cfg = {1.2, 1.4, 300, 0.3, 1e-5, 0.1, 0.2};
193+
else if (network == "B4") this->cfg = {1.4, 1.8, 380, 0.4, 1e-5, 0.1, 0.2};
194+
else if (network == "B5") this->cfg = {1.6, 2.2, 456, 0.4, 0.001, 0.01, 0.2};
195+
else if (network == "B6") this->cfg = {1.8, 2.6, 528, 0.5, 0.001, 0.01, 0.2};
196+
else if (network == "B7") this->cfg = {2.0, 3.1, 600, 0.5, 0.001, 0.01, 0.2};
197+
else if (network == "B8") this->cfg = {2.2, 3.6, 672, 0.5, 0.001, 0.01, 0.2};
198+
else if (network == "L2") this->cfg = {4.3, 5.3, 800, 0.5, 0.001, 0.01, 0.2};
199199
else{
200200
std::cerr << "Error : The type of network is " << network << '.' << std::endl;
201201
std::cerr << "Error : Please choose B0, B1, B2, B3, B4, B5, B6, B7, B8 or L2." << std::endl;
@@ -209,7 +209,7 @@ MC_EfficientNetImpl::MC_EfficientNetImpl(po::variables_map &vm){
209209

210210
// (1) Stem layer
211211
stem_nc = this->round_filters(stem_feature, this->cfg.width_mul);
212-
this->features->push_back(Conv2dNormActivation(vm["nc"].as<size_t>(), stem_nc, /*kernel_size=*/3, /*stride=*/2, /*padding=*/1, /*groups=*/1, /*SiLU=*/true));
212+
this->features->push_back(Conv2dNormActivation(vm["nc"].as<size_t>(), stem_nc, /*kernel_size=*/3, /*stride=*/2, /*padding=*/1, /*groups=*/1, /*eps=*/this->cfg.eps, /*momentum=*/this->cfg.momentum, /*SiLU=*/true));
213213

214214
// (2.a) Bone layer
215215
total_blocks = 0;
@@ -225,16 +225,16 @@ MC_EfficientNetImpl::MC_EfficientNetImpl(po::variables_map &vm){
225225
out_nc = this->round_filters(bcfg[i].c, this->cfg.width_mul);
226226
for (size_t j = 0; j < repeats; j++){
227227
stride = (j == 0) ? bcfg[i].s : 1;
228-
dropconnect = this->cfg.dropout * (double)block_idx / (double)std::max(1, int(total_blocks));
229-
this->features->push_back(MBConvImpl(in_nc, out_nc, bcfg[i].k, stride, bcfg[i].exp, dropconnect));
228+
dropconnect = this->cfg.stochastic_depth_prob * (double)block_idx / (double)std::max(1, int(total_blocks));
229+
this->features->push_back(MBConvImpl(in_nc, out_nc, bcfg[i].k, stride, bcfg[i].exp, this->cfg.eps, this->cfg.momentum, dropconnect));
230230
in_nc = out_nc;
231231
block_idx++;
232232
}
233233
}
234234

235235
// (3) Head layer
236236
head_nc = this->round_filters(head_feature, this->cfg.width_mul);
237-
this->features->push_back(Conv2dNormActivation(in_nc, head_nc, /*kernel_size=*/1, /*stride=*/1, /*padding=*/0, /*groups=*/1, /*SiLU=*/true));
237+
this->features->push_back(Conv2dNormActivation(in_nc, head_nc, /*kernel_size=*/1, /*stride=*/1, /*padding=*/0, /*groups=*/1, /*eps=*/this->cfg.eps, /*momentum=*/this->cfg.momentum, /*SiLU=*/true));
238238
register_module("features", this->features);
239239

240240
// (4) Global Average Pooling

Multiclass_Classification/EfficientNet/src/networks.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ struct EfficientNetConfig{
2121
double depth_mul;
2222
size_t image_size;
2323
double dropout;
24+
double eps;
25+
double momentum;
26+
double stochastic_depth_prob;
2427
};
2528

2629
// -------------------------------------------------
@@ -52,7 +55,7 @@ struct Conv2dNormActivationImpl : nn::Module{
5255
nn::Sequential model;
5356
public:
5457
Conv2dNormActivationImpl(){}
55-
Conv2dNormActivationImpl(const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t padding, const size_t groups, const bool SiLU);
58+
Conv2dNormActivationImpl(const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t padding, const size_t groups, const float eps, const float momentum, const bool SiLU);
5659
torch::Tensor forward(torch::Tensor x);
5760
};
5861
TORCH_MODULE(Conv2dNormActivation);
@@ -84,7 +87,7 @@ struct MBConvImpl : nn::Module{
8487
StochasticDepth sd;
8588
public:
8689
MBConvImpl(){}
87-
MBConvImpl(const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t exp, const float dropconnect);
90+
MBConvImpl(const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t exp, const float eps, const float momentum, const float dropconnect);
8891
torch::Tensor forward(torch::Tensor x);
8992
};
9093
TORCH_MODULE(MBConv);

0 commit comments

Comments
 (0)