@@ -304,10 +304,50 @@ def test_transpiler(self):
304
304
# TODO(typhoonzero): test clipping and L2Decay ops are removed from trainer
305
305
306
306
307
- # FIXME(typhoonzero): need to add test for async case:
308
- # see https://github.com/PaddlePaddle/Paddle/issues/11691
309
- class TestAsyncSGD (TranspilerTest ):
310
- pass
307
+ class TestL2DecayWithPiecewise (TranspilerTest ):
308
+ def net_conf (self ):
309
+ x = fluid .layers .data (name = 'x' , shape = [1000 ], dtype = 'float32' )
310
+ y_predict = fluid .layers .fc (input = x ,
311
+ size = 1000 ,
312
+ act = None ,
313
+ param_attr = fluid .ParamAttr (name = 'fc_w' ),
314
+ bias_attr = fluid .ParamAttr (name = 'fc_b' ))
315
+ y = fluid .layers .data (name = 'y' , shape = [1 ], dtype = 'float32' )
316
+ cost = fluid .layers .square_error_cost (input = y_predict , label = y )
317
+ avg_cost = fluid .layers .mean (cost )
318
+ base_lr = 1.0
319
+ bd = [1 , 10 , 20 , 30 ]
320
+ lr = [base_lr * (0.1 ** i ) for i in range (len (bd ) + 1 )]
321
+ sgd_optimizer = fluid .optimizer .Momentum (
322
+ learning_rate = fluid .layers .piecewise_decay (
323
+ boundaries = bd , values = lr ),
324
+ momentum = 0.9 ,
325
+ regularization = fluid .regularizer .L2Decay (1e-4 ))
326
+ sgd_optimizer .minimize (avg_cost )
327
+ return
328
+
329
+ def test_transpiler (self ):
330
+ pserver , startup = self .get_pserver (self .pserver1_ep )
331
+ trainer = self .get_trainer ()
332
+
333
+ self .assertEqual (len (pserver .blocks ), 9 )
334
+ self .assertEqual ([op .type for op in pserver .blocks [1 ].ops ], [
335
+ "increment" , "cast" , "fill_constant" , "fill_constant" , "less_than" ,
336
+ "logical_not" , "conditional_block" , "fill_constant" ,
337
+ "fill_constant" , "less_than" , "logical_not" , "logical_and" ,
338
+ "logical_and" , "conditional_block" , "fill_constant" ,
339
+ "fill_constant" , "less_than" , "logical_not" , "logical_and" ,
340
+ "logical_and" , "conditional_block" , "fill_constant" ,
341
+ "fill_constant" , "less_than" , "logical_not" , "logical_and" ,
342
+ "logical_and" , "conditional_block" , "fill_constant" ,
343
+ "conditional_block"
344
+ ])
345
+ self .assertEqual (
346
+ [op .type for op in pserver .blocks [7 ].ops ],
347
+ ["sum" , "scale" , "scale" , "elementwise_add" , "momentum" ])
348
+ self .assertEqual (
349
+ [op .type for op in pserver .blocks [8 ].ops ],
350
+ ["sum" , "scale" , "scale" , "elementwise_add" , "momentum" ])
311
351
312
352
313
353
if __name__ == "__main__" :
0 commit comments