File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed
Multiclass_Classification/EfficientNet/src Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -283,15 +283,14 @@ void weights_init(nn::Module &m){
283283 auto p = m.named_parameters (false );
284284 auto w = p.find (" weight" );
285285 auto b = p.find (" bias" );
286- if (w != nullptr ) nn::init::kaiming_normal_ (*w, /* a =*/ 0.0 , torch:: kFanOut );
286+ if (w != nullptr ) nn::init::normal_ (*w, /* mean =*/ 0.0 , /* std= */ 0.01 );
287287 if (b != nullptr ) nn::init::constant_ (*b, /* bias=*/ 0.0 );
288288 }
289289 else if ((typeid (m) == typeid (nn::Linear)) || (typeid (m) == typeid (nn::LinearImpl))){
290290 auto p = m.named_parameters (false );
291291 auto w = p.find (" weight" );
292292 auto b = p.find (" bias" );
293- double bound = 1.0 / std::sqrt ((double )(*w).size (0 ));
294- if (w != nullptr ) nn::init::uniform_ (*w, -bound, bound);
293+ if (w != nullptr ) nn::init::normal_ (*w, /* mean=*/ 0.0 , /* std=*/ 0.01 );
295294 if (b != nullptr ) nn::init::constant_ (*b, /* bias=*/ 0.0 );
296295 }
297296 else if ((typeid (m) == typeid (nn::BatchNorm2d)) || (typeid (m) == typeid (nn::BatchNorm2dImpl))){
You can’t perform that action at this time.
0 commit comments