-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbn_adaptation.py
More file actions
21 lines (17 loc) · 861 Bytes
/
bn_adaptation.py
File metadata and controls
21 lines (17 loc) · 861 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch.nn as nn
from models.cbna_vgg.networks.vgg_encoder import MyBatchNorm2d as BN_VGG
from models.cbna.networks.resnet_encoder import MyBatchNorm2d as BN_ResNet
class BNAdaptation:
    """Set the ``momentum`` attribute of every batch-norm layer in a model's encoder."""

    def process(self, model, momentum):
        """Assign ``momentum`` to the batch-norm layers of ``model``'s encoder.

        The encoder is reached through a different attribute path depending on
        the architecture. A model whose class name (as reported by
        ``model._get_name()``) is ``'UBNAVGG'`` is treated as VGG-based. Its
        layers are searched under ``model.common.encoder.encoder.features``.
        Every other model is treated as ResNet-based and searched under
        ``model.common.encoder.encoder``.

        Args:
            model: The network to adapt. It must expose the
                ``common.encoder.encoder`` attribute path used below.
            momentum: Value written to each matching layer's ``momentum``
                attribute. In PyTorch batch norm this governs the running
                mean/var update (``None`` selects a cumulative average).

        Returns:
            The same ``model`` object. It is modified in place.
        """
        # Layers are accessed differently for different encoder architectures.
        if model._get_name() == 'UBNAVGG':
            modules = model.common.encoder.encoder.features.modules()
            bn_types = (nn.BatchNorm2d, BN_VGG)
        else:
            modules = model.common.encoder.encoder.modules()
            bn_types = (nn.BatchNorm2d, BN_ResNet)

        # isinstance (rather than exact type equality) so that subclasses of
        # these batch-norm classes are adapted too instead of being skipped.
        for module in modules:
            if isinstance(module, bn_types):
                module.momentum = momentum
        return model