Multiclass_Classification/EfficientNet/src — 1 file changed: +4 −2 lines changed.
```diff
@@ -3,6 +3,7 @@
 #include <algorithm>
 #include <typeinfo>
 #include <cstdlib>
+#include <cmath>
 // For External Library
 #include <torch/torch.h>
 // For Original Header
@@ -282,14 +283,15 @@ void weights_init(nn::Module &m){
         auto p = m.named_parameters(false);
         auto w = p.find("weight");
         auto b = p.find("bias");
-        if (w != nullptr) nn::init::normal_(*w, /*mean=*/0.0, /*std=*/0.01);
+        if (w != nullptr) nn::init::kaiming_normal_(*w, /*a=*/0.0, torch::kFanOut);
         if (b != nullptr) nn::init::constant_(*b, /*bias=*/0.0);
     }
     else if ((typeid(m) == typeid(nn::Linear)) || (typeid(m) == typeid(nn::LinearImpl))){
         auto p = m.named_parameters(false);
         auto w = p.find("weight");
         auto b = p.find("bias");
-        if (w != nullptr) nn::init::normal_(*w, /*mean=*/0.0, /*std=*/0.01);
+        double bound = (w != nullptr) ? 1.0 / std::sqrt((double)(*w).size(0)) : 0.0;  // guard: w may be nullptr
+        if (w != nullptr) nn::init::uniform_(*w, -bound, bound);
         if (b != nullptr) nn::init::constant_(*b, /*bias=*/0.0);
     }
     else if ((typeid(m) == typeid(nn::BatchNorm2d)) || (typeid(m) == typeid(nn::BatchNorm2dImpl))){
```
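For context, here is a minimal self-contained sketch of how the initializer reads after this commit and how it is typically wired up in LibTorch. The diff only shows the Linear and BatchNorm2d type checks plus the edited init calls, so the Conv2d check in the first branch, the toy `nn::Sequential` model, and `main()` are illustrative assumptions, not part of the commit.

```cpp
#include <cmath>
#include <iostream>
#include <typeinfo>
#include <torch/torch.h>

using namespace torch;

// Per-module initializer in the style of this commit; meant to be passed to
// torch::nn::Module::apply(), which visits every submodule.
void weights_init(nn::Module &m){
    // Assumed branch: the type check above the hunk is not visible in the diff.
    if ((typeid(m) == typeid(nn::Conv2d)) || (typeid(m) == typeid(nn::Conv2dImpl))){
        auto p = m.named_parameters(false);
        auto w = p.find("weight");
        auto b = p.find("bias");
        // He/Kaiming normal in fan-out mode: std = sqrt(2 / fan_out).
        if (w != nullptr) nn::init::kaiming_normal_(*w, /*a=*/0.0, torch::kFanOut);
        if (b != nullptr) nn::init::constant_(*b, /*bias=*/0.0);
    }
    else if ((typeid(m) == typeid(nn::Linear)) || (typeid(m) == typeid(nn::LinearImpl))){
        auto p = m.named_parameters(false);
        auto w = p.find("weight");
        auto b = p.find("bias");
        if (w != nullptr){
            // Uniform in [-1/sqrt(n), 1/sqrt(n)]. Note the commit uses size(0)
            // (out_features); PyTorch's own Linear default uses fan_in = size(1).
            double bound = 1.0 / std::sqrt((double)(*w).size(0));
            nn::init::uniform_(*w, -bound, bound);
        }
        if (b != nullptr) nn::init::constant_(*b, /*bias=*/0.0);
    }
}

int main(){
    auto net = nn::Sequential(
        nn::Conv2d(nn::Conv2dOptions(/*in=*/3, /*out=*/8, /*kernel=*/3).padding(1)),
        nn::BatchNorm2d(8),
        nn::ReLU(),
        nn::Flatten(),
        nn::Linear(8 * 32 * 32, 10)
    );
    net->apply(weights_init);  // runs weights_init on every submodule
    std::cout << net->forward(torch::randn({2, 3, 32, 32})).sizes() << std::endl;
    return 0;
}
```

The design change: `normal_(0, 0.01)` gave every layer the same fixed std, while `kaiming_normal_` with `kFanOut` scales it as sqrt(2/fan_out), so wider layers receive proportionally smaller weights. The new `<cmath>` include supplies the `std::sqrt` used for the Linear bound.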