@@ -162,6 +162,13 @@ def __init__(self):
162162 logger .info ('default test period: {}' .format (self .test_period ))
163163 logger .info ('=' * 30 )
164164
165+ def set_model_name (self , model_name ):
166+ """
167+ Set model name, eg. "ResNet50", "ResNet101"
168+ """
169+ self .model_name = model_name
170+ logger .info ("Set model_name to {}." .format (model_name ))
171+
165172 def set_use_quant (self , quant ):
166173 """
167174 Whether to use quantization
@@ -203,8 +210,9 @@ def set_val_targets(self, targets):
203210 def set_train_batch_size (self , batch_size ):
204211 self .train_batch_size = batch_size
205212 self .global_train_batch_size = batch_size * self .num_trainers
206- logger .info ("Set train batch size per trainer to {}." .format (
207- batch_size ))
213+ logger .info ("Set train batch size per trainer to {}, global "
214+ "train batch size to {}." .format (
215+ batch_size , self .global_train_batch_size ))
208216
209217 def set_log_period (self , period ):
210218 self .log_period = period
@@ -219,8 +227,8 @@ def set_lr_decay_factor(self, factor):
219227 logger .info ("Set lr decay factor to {}." .format (factor ))
220228
221229 def set_step_boundaries (self , boundaries ):
222- if not isinstance (boundaries , list ):
223- raise ValueError ("The parameter must be of type list." )
230+ if not isinstance (boundaries , ( tuple , list ) ):
231+ raise ValueError ("The parameter must be of type tuple or list." )
224232 self .lr_steps = boundaries
225233 logger .info ("Set step boundaries to {}." .format (boundaries ))
226234
0 commit comments