114
114
# Initialize global variables. We use this function so that we can
115
115
# call parse_config() multiple times
116
116
def init_config_environment (
117
- g_default_momentum = 0. ,
118
- g_default_decay_rate = 0. ,
117
+ g_default_momentum = None ,
118
+ g_default_decay_rate = None ,
119
119
g_default_initial_mean = 0. ,
120
120
g_default_initial_std = 0.01 ,
121
- g_default_num_batches_regularization = 1 ,
121
+ g_default_num_batches_regularization = None ,
122
122
g_default_initial_strategy = 0 ,
123
123
g_default_initial_smart = False ,
124
- g_default_gradient_clipping_threshold = 0. ,
125
- g_default_device = - 1 ,
124
+ g_default_gradient_clipping_threshold = None ,
125
+ g_default_device = None ,
126
126
g_default_update_hooks = None ,
127
127
g_default_compact_func = None ,
128
128
@@ -1099,12 +1099,12 @@ def Evaluator(
1099
1099
inputs ,
1100
1100
chunk_scheme = None ,
1101
1101
num_chunk_types = None ,
1102
- classification_threshold = 0.5 ,
1103
- positive_label = - 1 ,
1104
- dict_file = "" ,
1105
- result_file = "" ,
1106
- num_results = 1 ,
1107
- delimited = True ,
1102
+ classification_threshold = None ,
1103
+ positive_label = None ,
1104
+ dict_file = None ,
1105
+ result_file = None ,
1106
+ num_results = None ,
1107
+ delimited = None ,
1108
1108
):
1109
1109
evaluator = g_config .model_config .evaluators .add ()
1110
1110
evaluator .type = type
@@ -1120,12 +1120,19 @@ def Evaluator(
1120
1120
evaluator .num_chunk_types = num_chunk_types
1121
1121
g_current_submodel .evaluator_names .append (evaluator .name )
1122
1122
1123
- evaluator .classification_threshold = classification_threshold
1124
- evaluator .positive_label = positive_label
1125
- evaluator .dict_file = dict_file
1126
- evaluator .result_file = result_file
1127
- evaluator .num_results = num_results
1128
- evaluator .delimited = delimited
1123
+ if classification_threshold is not None :
1124
+ evaluator .classification_threshold = classification_threshold
1125
+ if positive_label is not None :
1126
+ evaluator .positive_label = positive_label
1127
+ if dict_file is not None :
1128
+ evaluator .dict_file = dict_file
1129
+
1130
+ if result_file is not None :
1131
+ evaluator .result_file = result_file
1132
+ if num_results is not None :
1133
+ evaluator .num_results = num_results
1134
+ if delimited is not None :
1135
+ evaluator .delimited = delimited
1129
1136
1130
1137
class LayerBase (object ):
1131
1138
def __init__ (
@@ -1137,7 +1144,7 @@ def __init__(
1137
1144
device = None ,
1138
1145
active_type = "" ,
1139
1146
drop_rate = 0. ,
1140
- coeff = 1. ):
1147
+ coeff = None ):
1141
1148
config_assert ('@' not in name ,
1142
1149
"layer name: %s contain special character @" % name )
1143
1150
global g_current_submodel
@@ -1155,18 +1162,20 @@ def __init__(
1155
1162
self .inputs = [self .inputs ]
1156
1163
1157
1164
self .config = g_config .model_config .layers .add ()
1165
+ assert isinstance (self .config , LayerConfig )
1158
1166
self .config .name = name
1159
1167
self .config .type = type
1160
1168
self .config .active_type = active_type
1161
- self .config .coeff = coeff
1169
+ if coeff is not None :
1170
+ self .config .coeff = float (coeff )
1162
1171
if size != 0 :
1163
1172
self .config .size = size
1164
1173
if drop_rate != 0 :
1165
1174
self .config .drop_rate = drop_rate
1166
1175
1167
1176
if device is not None :
1168
1177
self .config .device = device
1169
- else :
1178
+ elif g_default_device is not None :
1170
1179
self .config .device = g_default_device
1171
1180
1172
1181
for input_index in xrange (len (self .inputs )):
@@ -1236,10 +1245,12 @@ def create_bias_parameter(
1236
1245
if bias .parameter_name is None :
1237
1246
bias .parameter_name = gen_bias_parameter_name (self .config .name )
1238
1247
if bias .parameter_name not in g_parameter_map :
1248
+ assert isinstance (self .config , LayerConfig )
1249
+
1239
1250
Parameter (
1240
1251
bias .parameter_name ,
1241
1252
size ,
1242
- self .config .device ,
1253
+ self .config .device if self . config . HasField ( 'device' ) else None ,
1243
1254
dims ,
1244
1255
bias .learning_rate ,
1245
1256
bias .momentum ,
@@ -1265,7 +1276,7 @@ def create_input_parameter(
1265
1276
input_index ,
1266
1277
size ,
1267
1278
dims = None ,
1268
- sparse = False ,
1279
+ sparse = None ,
1269
1280
format = "csr" ):
1270
1281
if dims is None :
1271
1282
# TODO(yuyang18): print warning and callstack here!
@@ -1293,7 +1304,7 @@ def create_input_parameter(
1293
1304
Parameter (
1294
1305
input_config .parameter_name ,
1295
1306
size ,
1296
- self .config .device ,
1307
+ self .config .device if self . config . HasField ( "device" ) else None ,
1297
1308
dims ,
1298
1309
input_config .learning_rate ,
1299
1310
input_config .momentum ,
@@ -1353,6 +1364,8 @@ def __init__(
1353
1364
1354
1365
if sparse :
1355
1366
psize = self .inputs [input_index ].nnz
1367
+ else :
1368
+ sparse = None
1356
1369
1357
1370
self .create_input_parameter (input_index , psize , dims , sparse , format )
1358
1371
self .create_bias_parameter (bias , self .config .size )
@@ -2836,27 +2849,44 @@ def Parameter(
2836
2849
para = g_config .model_config .parameters .add ()
2837
2850
para .name = name
2838
2851
para .size = size
2839
- para .device = device
2840
- para .dims .extend (dims );
2841
- para .learning_rate = default (learning_rate , 1. )
2842
- para .momentum = default (momentum , g_default_momentum )
2852
+ if device is not None :
2853
+ para .device = int (device )
2854
+ para .dims .extend (dims )
2855
+
2856
+ if learning_rate is not None :
2857
+ para .learning_rate = float (learning_rate )
2858
+
2859
+ momentum = default (momentum , g_default_momentum )
2860
+ if momentum is not None :
2861
+ para .momentum = float (momentum )
2862
+
2843
2863
config_assert (not momentum or not decay_rate_l1 ,
2844
2864
"momentum and decay_rate_l1 cannot both be non-zero" )
2845
- para .decay_rate = default (decay_rate , g_default_decay_rate )
2865
+
2866
+ decay_rate = default (decay_rate , g_default_decay_rate )
2867
+ if decay_rate is not None :
2868
+ para .decay_rate = decay_rate
2869
+
2846
2870
if decay_rate_l1 is not None :
2847
2871
para .decay_rate_l1 = decay_rate_l1
2848
2872
para .initial_std = default (initial_std , g_default_initial_std )
2849
2873
para .initial_mean = default (initial_mean , g_default_initial_mean )
2850
- para .num_batches_regularization = default (
2874
+
2875
+ num_batches_regularization = default (
2851
2876
num_batches_regularization , g_default_num_batches_regularization )
2877
+ if num_batches_regularization is not None :
2878
+ para .num_batches_regularization = int (num_batches_regularization )
2879
+
2852
2880
if sparse_remote_update is not None :
2853
2881
para .sparse_remote_update = sparse_remote_update
2854
2882
if sparse_remote_update :
2855
2883
g_config .opt_config .use_sparse_remote_updater = True
2856
2884
if sparse_update is not None :
2857
2885
para .sparse_update = sparse_update
2858
- para .gradient_clipping_threshold = default (
2859
- gradient_clipping_threshold , g_default_gradient_clipping_threshold );
2886
+ gradient_clipping_threshold = default (
2887
+ gradient_clipping_threshold , g_default_gradient_clipping_threshold )
2888
+ if gradient_clipping_threshold is not None :
2889
+ para .gradient_clipping_threshold = gradient_clipping_threshold
2860
2890
para .initial_strategy = default (initial_strategy , g_default_initial_strategy )
2861
2891
para .initial_smart = default (initial_smart , g_default_initial_smart )
2862
2892
if para .initial_smart :
@@ -2869,15 +2899,19 @@ def Parameter(
2869
2899
para .initial_std = 1. / math .sqrt (para .size )
2870
2900
if g_default_compact_func is not None :
2871
2901
sparse , format , need_compact = g_default_compact_func (para .name )
2872
- para .is_sparse = default (sparse , False )
2873
- para .format = default (format , "" )
2874
- para .need_compact = default (need_compact , False )
2902
+
2903
+ if sparse is not None :
2904
+ para .is_sparse = sparse
2905
+ if format is not None :
2906
+ para .format = format
2907
+ if need_compact is not None :
2908
+ para .need_compact = need_compact
2875
2909
if is_static is not None :
2876
2910
para .is_static = is_static
2877
2911
config_assert (not para .sparse_remote_update or not para .is_static ,
2878
2912
"sparse_remote_update and is_static cannot both be true" )
2879
-
2880
- para .is_shared = default ( is_shared , False )
2913
+ if is_shared is not None :
2914
+ para .is_shared = is_shared
2881
2915
2882
2916
update_hooks = default (update_hooks , g_default_update_hooks )
2883
2917
0 commit comments