Skip to content

Commit 1efdec2

Browse files
authored
Add files via upload
1 parent 7a8c5be commit 1efdec2

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

Multiclass_Classification/EfficientNet/src/networks.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <algorithm>
44
#include <typeinfo>
55
#include <cstdlib>
6+
#include <cmath>
67
// For External Library
78
#include <torch/torch.h>
89
// For Original Header
@@ -282,14 +283,15 @@ void weights_init(nn::Module &m){
282283
auto p = m.named_parameters(false);
283284
auto w = p.find("weight");
284285
auto b = p.find("bias");
285-
if (w != nullptr) nn::init::normal_(*w, /*mean=*/0.0, /*std=*/0.01);
286+
if (w != nullptr) nn::init::kaiming_normal_(*w, /*a=*/0.0, torch::kFanOut);
286287
if (b != nullptr) nn::init::constant_(*b, /*bias=*/0.0);
287288
}
288289
else if ((typeid(m) == typeid(nn::Linear)) || (typeid(m) == typeid(nn::LinearImpl))){
289290
auto p = m.named_parameters(false);
290291
auto w = p.find("weight");
291292
auto b = p.find("bias");
292-
if (w != nullptr) nn::init::normal_(*w, /*mean=*/0.0, /*std=*/0.01);
293+
double bound = 1.0 / std::sqrt((double)(*w).size(0));
294+
if (w != nullptr) nn::init::uniform_(*w, -bound, bound);
293295
if (b != nullptr) nn::init::constant_(*b, /*bias=*/0.0);
294296
}
295297
else if ((typeid(m) == typeid(nn::BatchNorm2d)) || (typeid(m) == typeid(nn::BatchNorm2dImpl))){

0 commit comments

Comments
 (0)