@@ -54,10 +54,10 @@ void StochasticDepthImpl::pretty_print(std::ostream& stream) const{
5454// ----------------------------------------------------------------------
5555// struct{Conv2dNormActivationImpl}(nn::Module) -> constructor
5656// ----------------------------------------------------------------------
57- Conv2dNormActivationImpl::Conv2dNormActivationImpl (const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t padding, const size_t groups, const bool SiLU){
57+ Conv2dNormActivationImpl::Conv2dNormActivationImpl (const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t padding, const size_t groups, const float eps, const float momentum, const bool SiLU){
5858 this ->model = nn::Sequential (
5959 nn::Conv2d (nn::Conv2dOptions (/* in_channels=*/ in_nc, /* out_channels=*/ out_nc, /* kernel_size=*/ kernel_size).stride (stride).padding (padding).groups (groups).bias (false )),
60- nn::BatchNorm2d (out_nc)
60+ nn::BatchNorm2d (nn::BatchNormOptions ( out_nc). eps (eps). momentum (momentum) )
6161 );
6262 if (SiLU) this ->model ->push_back (nn::SiLU ());
6363 register_module (" Conv2dNormActivation" , this ->model );
@@ -113,16 +113,16 @@ torch::Tensor SqueezeExcitationImpl::forward(torch::Tensor x){
113113// ----------------------------------------------------------------------
114114// struct{MBConvImpl}(nn::Module) -> constructor
115115// ----------------------------------------------------------------------
116- MBConvImpl::MBConvImpl (const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t exp, const float dropconnect){
116+ MBConvImpl::MBConvImpl (const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t exp, const float eps, const float momentum, const float dropconnect){
117117
118118 constexpr size_t reduce = 4 ;
119119 size_t mid = in_nc * exp;
120120 this ->residual = ((stride == 1 ) && (in_nc == out_nc));
121121
122- if (exp != 1 ) this ->block ->push_back (Conv2dNormActivation (in_nc, mid, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* SiLU=*/ true ));
123- this ->block ->push_back (Conv2dNormActivation (mid, mid, /* kernel_size=*/ kernel_size, /* stride=*/ stride, /* padding=*/ kernel_size / 2 , /* groups=*/ mid, /* SiLU=*/ true ));
122+ if (exp != 1 ) this ->block ->push_back (Conv2dNormActivation (in_nc, mid, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* eps= */ eps, /* momentum= */ momentum, /* SiLU=*/ true ));
123+ this ->block ->push_back (Conv2dNormActivation (mid, mid, /* kernel_size=*/ kernel_size, /* stride=*/ stride, /* padding=*/ kernel_size / 2 , /* groups=*/ mid, /* eps= */ eps, /* momentum= */ momentum, /* SiLU=*/ true ));
124124 this ->block ->push_back (SqueezeExcitation (mid, std::max (1 , int (in_nc / reduce))));
125- this ->block ->push_back (Conv2dNormActivation (mid, out_nc, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* SiLU=*/ false ));
125+ this ->block ->push_back (Conv2dNormActivation (mid, out_nc, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* eps= */ eps, /* momentum= */ momentum, /* SiLU=*/ false ));
126126 register_module (" block" , this ->block );
127127
128128 this ->sd = StochasticDepth (dropconnect);
@@ -170,7 +170,7 @@ size_t MC_EfficientNetImpl::round_filters(size_t c, double width_mul){
170170// struct{MC_EfficientNetImpl}(nn::Module) -> function{round_repeats}
171171// ----------------------------------------------------------------------
172172size_t MC_EfficientNetImpl::round_repeats (size_t r, double depth_mul){
173- return std::max (1 , int (std::round (r * depth_mul)));
173+ return std::max (1 , int (std::ceil (r * depth_mul)));
174174}
175175
176176
@@ -186,16 +186,16 @@ MC_EfficientNetImpl::MC_EfficientNetImpl(po::variables_map &vm){
186186
187187 // (0.a) Setting for network's config
188188 std::string network = vm[" network" ].as <std::string>();
189- if (network == " B0" ) this ->cfg = {1.0 , 1.0 , 224 , 0.2 };
190- else if (network == " B1" ) this ->cfg = {1.0 , 1.1 , 240 , 0.2 };
191- else if (network == " B2" ) this ->cfg = {1.1 , 1.2 , 260 , 0.3 };
192- else if (network == " B3" ) this ->cfg = {1.2 , 1.4 , 300 , 0.3 };
193- else if (network == " B4" ) this ->cfg = {1.4 , 1.8 , 380 , 0.4 };
194- else if (network == " B5" ) this ->cfg = {1.6 , 2.2 , 456 , 0.4 };
195- else if (network == " B6" ) this ->cfg = {1.8 , 2.6 , 528 , 0.5 };
196- else if (network == " B7" ) this ->cfg = {2.0 , 3.1 , 600 , 0.5 };
197- else if (network == " B8" ) this ->cfg = {2.2 , 3.6 , 672 , 0.5 };
198- else if (network == " L2" ) this ->cfg = {4.3 , 5.3 , 800 , 0.5 };
189+ if (network == " B0" ) this ->cfg = {1.0 , 1.0 , 224 , 0.2 , 1e-5 , 0.1 , 0.2 };
190+ else if (network == " B1" ) this ->cfg = {1.0 , 1.1 , 240 , 0.2 , 1e-5 , 0.1 , 0.2 };
191+ else if (network == " B2" ) this ->cfg = {1.1 , 1.2 , 260 , 0.3 , 1e-5 , 0.1 , 0.2 };
192+ else if (network == " B3" ) this ->cfg = {1.2 , 1.4 , 300 , 0.3 , 1e-5 , 0.1 , 0.2 };
193+ else if (network == " B4" ) this ->cfg = {1.4 , 1.8 , 380 , 0.4 , 1e-5 , 0.1 , 0.2 };
194+ else if (network == " B5" ) this ->cfg = {1.6 , 2.2 , 456 , 0.4 , 0.001 , 0.01 , 0.2 };
195+ else if (network == " B6" ) this ->cfg = {1.8 , 2.6 , 528 , 0.5 , 0.001 , 0.01 , 0.2 };
196+ else if (network == " B7" ) this ->cfg = {2.0 , 3.1 , 600 , 0.5 , 0.001 , 0.01 , 0.2 };
197+ else if (network == " B8" ) this ->cfg = {2.2 , 3.6 , 672 , 0.5 , 0.001 , 0.01 , 0.2 };
198+ else if (network == " L2" ) this ->cfg = {4.3 , 5.3 , 800 , 0.5 , 0.001 , 0.01 , 0.2 };
199199 else {
200200 std::cerr << " Error : The type of network is " << network << ' .' << std::endl;
201201 std::cerr << " Error : Please choose B0, B1, B2, B3, B4, B5, B6, B7, B8 or L2." << std::endl;
@@ -209,7 +209,7 @@ MC_EfficientNetImpl::MC_EfficientNetImpl(po::variables_map &vm){
209209
210210 // (1) Stem layer
211211 stem_nc = this ->round_filters (stem_feature, this ->cfg .width_mul );
212- this ->features ->push_back (Conv2dNormActivation (vm[" nc" ].as <size_t >(), stem_nc, /* kernel_size=*/ 3 , /* stride=*/ 2 , /* padding=*/ 1 , /* groups=*/ 1 , /* SiLU=*/ true ));
212+ this ->features ->push_back (Conv2dNormActivation (vm[" nc" ].as <size_t >(), stem_nc, /* kernel_size=*/ 3 , /* stride=*/ 2 , /* padding=*/ 1 , /* groups=*/ 1 , /* eps= */ this -> cfg . eps , /* momentum= */ this -> cfg . momentum , /* SiLU=*/ true ));
213213
214214 // (2.a) Bone layer
215215 total_blocks = 0 ;
@@ -225,16 +225,16 @@ MC_EfficientNetImpl::MC_EfficientNetImpl(po::variables_map &vm){
225225 out_nc = this ->round_filters (bcfg[i].c , this ->cfg .width_mul );
226226 for (size_t j = 0 ; j < repeats; j++){
227227 stride = (j == 0 ) ? bcfg[i].s : 1 ;
228- dropconnect = this ->cfg .dropout * (double )block_idx / (double )std::max (1 , int (total_blocks));
229- this ->features ->push_back (MBConvImpl (in_nc, out_nc, bcfg[i].k , stride, bcfg[i].exp , dropconnect));
228+ dropconnect = this ->cfg .stochastic_depth_prob * (double )block_idx / (double )std::max (1 , int (total_blocks));
229+ this ->features ->push_back (MBConvImpl (in_nc, out_nc, bcfg[i].k , stride, bcfg[i].exp , this -> cfg . eps , this -> cfg . momentum , dropconnect));
230230 in_nc = out_nc;
231231 block_idx++;
232232 }
233233 }
234234
235235 // (3) Head layer
236236 head_nc = this ->round_filters (head_feature, this ->cfg .width_mul );
237- this ->features ->push_back (Conv2dNormActivation (in_nc, head_nc, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* SiLU=*/ true ));
237+ this ->features ->push_back (Conv2dNormActivation (in_nc, head_nc, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* eps= */ this -> cfg . eps , /* momentum= */ this -> cfg . momentum , /* SiLU=*/ true ));
238238 register_module (" features" , this ->features );
239239
240240 // (4) Global Average Pooling
0 commit comments