Skip to content

Commit 7be63c3

Browse files
authored
Add files via upload
1 parent 1efdec2 commit 7be63c3

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

Multiclass_Classification/EfficientNet/src/networks.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,15 +283,14 @@ void weights_init(nn::Module &m){
283283
auto p = m.named_parameters(false);
284284
auto w = p.find("weight");
285285
auto b = p.find("bias");
286-
if (w != nullptr) nn::init::kaiming_normal_(*w, /*a=*/0.0, torch::kFanOut);
286+
if (w != nullptr) nn::init::normal_(*w, /*mean=*/0.0, /*std=*/0.01);
287287
if (b != nullptr) nn::init::constant_(*b, /*bias=*/0.0);
288288
}
289289
else if ((typeid(m) == typeid(nn::Linear)) || (typeid(m) == typeid(nn::LinearImpl))){
290290
auto p = m.named_parameters(false);
291291
auto w = p.find("weight");
292292
auto b = p.find("bias");
293-
double bound = 1.0 / std::sqrt((double)(*w).size(0));
294-
if (w != nullptr) nn::init::uniform_(*w, -bound, bound);
293+
if (w != nullptr) nn::init::normal_(*w, /*mean=*/0.0, /*std=*/0.01);
295294
if (b != nullptr) nn::init::constant_(*b, /*bias=*/0.0);
296295
}
297296
else if ((typeid(m) == typeid(nn::BatchNorm2d)) || (typeid(m) == typeid(nn::BatchNorm2dImpl))){

0 commit comments

Comments
 (0)