Skip to content

Commit 1b62f7f

Browse files
authored
fix res100 bug (#86)
* support res100
1 parent bb786d8 commit 1b62f7f

File tree

4 files changed

+20
-7
lines changed

4 files changed

+20
-7
lines changed

plsc/entry.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,13 @@ def __init__(self):
162162
logger.info('default test period: {}'.format(self.test_period))
163163
logger.info('=' * 30)
164164

165+
def set_model_name(self, model_name):
166+
"""
167+
Set model name, eg. "ResNet50", "ResNet101"
168+
"""
169+
self.model_name = model_name
170+
logger.info("Set model_name to {}.".format(model_name))
171+
165172
def set_use_quant(self, quant):
166173
"""
167174
Whether to use quantization
@@ -203,8 +210,9 @@ def set_val_targets(self, targets):
203210
def set_train_batch_size(self, batch_size):
204211
self.train_batch_size = batch_size
205212
self.global_train_batch_size = batch_size * self.num_trainers
206-
logger.info("Set train batch size per trainer to {}.".format(
207-
batch_size))
213+
logger.info("Set train batch size per trainer to {}, global "
214+
"train batch size to {}.".format(
215+
batch_size, self.global_train_batch_size))
208216

209217
def set_log_period(self, period):
210218
self.log_period = period
@@ -219,8 +227,8 @@ def set_lr_decay_factor(self, factor):
219227
logger.info("Set lr decay factor to {}.".format(factor))
220228

221229
def set_step_boundaries(self, boundaries):
222-
if not isinstance(boundaries, list):
223-
raise ValueError("The parameter must be of type list.")
230+
if not isinstance(boundaries, (tuple, list)):
231+
raise ValueError("The parameter must be of type tuple or list.")
224232
self.lr_steps = boundaries
225233
logger.info("Set step boundaries to {}.".format(boundaries))
226234

plsc/models/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import paddle
1818
from paddle.utils import unique_name
19-
from paddle.distributed.fleet.utils.plsc_util import class_center_sample
19+
from paddle.distributed.fleet.utils import class_center_sample
2020

2121
from . import dist_algo
2222

plsc/models/dist_algo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import paddle.nn as nn
2323
import paddle.utils.unique_name as unique_name
2424
from paddle.optimizer import Optimizer
25-
from paddle.distributed.fleet.utils.plsc_util import class_center_sample
25+
from paddle.distributed.fleet.utils import class_center_sample
2626
from ..utils.fp16_utils import rewrite_program, update_role_var_grad
2727
from ..utils.fp16_utils import update_loss_scaling, move_optimize_ops_back, check_finite_and_unscale
2828
from ..utils.fp16_lists import AutoMixedPrecisionLists

plsc/models/resnet.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from .base_model import BaseModel
1818

19-
__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"]
19+
__all__ = ["ResNet", "ResNet50", "ResNet100", "ResNet101", "ResNet152"]
2020

2121

2222
class ResNet(BaseModel):
@@ -184,6 +184,11 @@ def ResNet50(emb_dim=512):
184184
return model
185185

186186

187+
def ResNet100(emb_dim=512):
188+
model = ResNet(layers=100, emb_dim=emb_dim)
189+
return model
190+
191+
187192
def ResNet101(emb_dim=512):
188193
model = ResNet(layers=101, emb_dim=emb_dim)
189194
return model

0 commit comments

Comments
 (0)