Skip to content

Commit ac1c8fc

Browse files
committed
Update torch_pretrained_net.py
1 parent 55d2edc commit ac1c8fc

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

pymic/net/cls/torch_pretrained_net.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(self, params):
7575
def get_parameters_to_update(self):
7676
if(self.update_mode == "all"):
7777
return self.net.parameters()
78-
elif(self.update_layers == "last"):
78+
elif(self.update_mode == "last"):
7979
params = self.net.fc.parameters()
8080
if(self.in_chns !=3):
8181
# combining the two iterables into a single one
@@ -119,7 +119,7 @@ def get_parameters_to_update(self):
119119
params = self.net.classifier[-1].parameters()
120120
if(self.in_chns !=3):
121121
params = itertools.chain()
122-
for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0].parameters()]:
122+
for pram in [self.net.classifier[-1].parameters(), self.net.features[0].parameters()]:
123123
params = itertools.chain(params, pram)
124124
return params
125125
else:
@@ -138,7 +138,7 @@ class MobileNetV2(BuiltInNet):
138138
as well as the first layer when `input_chns` is not 3.
139139
"""
140140
def __init__(self, params):
141-
super(MobileNetV2, self).__init__()
141+
super(MobileNetV2, self).__init__(params)
142142
self.net = models.mobilenet_v2(pretrained = self.pretrain)
143143

144144
# replace the last layer
@@ -157,7 +157,7 @@ def get_parameters_to_update(self):
157157
params = self.net.classifier[-1].parameters()
158158
if(self.in_chns !=3):
159159
params = itertools.chain()
160-
for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0][0].parameters()]:
160+
for pram in [self.net.classifier[-1].parameters(), self.net.features[0][0].parameters()]:
161161
params = itertools.chain(params, pram)
162162
return params
163163
else:

0 commit comments

Comments
 (0)