1
- from paddle .v2 .framework .layer_helper import LayerHelper
1
+ from paddle .v2 .framework .layer_helper import LayerHelper , unique_name
2
2
import paddle .v2 .framework .core as core
3
- from paddle .v2 .framework .framework import OpProtoHolder , Variable
3
+ from paddle .v2 .framework .framework import OpProtoHolder , Variable , Program
4
4
import re
5
5
6
6
__all__ = [
7
- 'fc' , 'data' , 'cross_entropy' , 'conv2d' , 'pool2d' , 'embedding' , 'concat'
7
+ 'fc' , 'data' , 'cross_entropy' , 'conv2d' , 'pool2d' , 'embedding' , 'concat' ,
8
+ 'StaticRNN'
8
9
]
9
10
10
11
@@ -26,7 +27,9 @@ def fc(input,
26
27
mul_results = []
27
28
for input_var , param_attr in helper .iter_inputs_and_params ():
28
29
input_shape = input_var .shape
29
- param_shape = list (input_shape [num_flatten_dims :]) + [size ]
30
+ param_shape = [
31
+ reduce (lambda a , b : a * b , input_shape [num_flatten_dims :], 1 )
32
+ ] + [size ]
30
33
31
34
w = helper .create_parameter (
32
35
attr = param_attr , shape = param_shape , dtype = dtype )
@@ -38,10 +41,8 @@ def fc(input,
38
41
"Y" : w ,
39
42
},
40
43
outputs = {"Out" : tmp },
41
- attrs = {
42
- 'x_num_col_dims' : num_flatten_dims ,
43
- 'y_num_col_dims' : len (input_shape ) - num_flatten_dims
44
- })
44
+ attrs = {'x_num_col_dims' : num_flatten_dims ,
45
+ 'y_num_col_dims' : 1 })
45
46
mul_results .append (tmp )
46
47
47
48
# sum
@@ -273,3 +274,170 @@ def pool2d(input,
273
274
})
274
275
275
276
return pool_out
277
+
278
+
279
class BlockGuard(object):
    """
    Context manager that opens a fresh sub-block in a Program for the
    duration of a Python ``with`` statement.

    On entry a new block is pushed onto the program; on exit the program
    is rolled back to the enclosing block. Exceptions raised inside the
    ``with`` body are propagated to the caller.
    """

    def __init__(self, program):
        # Fail fast on anything that is not a Program instance.
        if not isinstance(program, Program):
            raise TypeError("BlockGuard takes a program")
        self.program = program

    def __enter__(self):
        self.program.create_block()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.program.rollback()
        # True suppresses nothing (no exception occurred);
        # False lets an in-flight exception re-raise.
        return exc_type is None
298
+
299
+
300
class StaticRNNGuard(BlockGuard):
    """
    BlockGuard specialized for :class:`StaticRNN`.

    Entering the guard switches the rnn's status to ``IN_RNN_BLOCK`` and
    opens the step sub-block; leaving it switches the status to
    ``AFTER_RNN_BLOCK``, finalizes the rnn op, and closes the sub-block.
    """

    def __init__(self, rnn):
        if not isinstance(rnn, StaticRNN):
            raise TypeError("StaticRNNGuard takes an StaticRNN")
        super(StaticRNNGuard, self).__init__(rnn.helper.program)
        self.rnn = rnn

    def __enter__(self):
        self.rnn.status = StaticRNN.IN_RNN_BLOCK
        return super(StaticRNNGuard, self).__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
        # Only finalize the rnn op when the step block completed cleanly.
        # Completing on a half-built block after an exception could raise a
        # secondary error and mask the original one; the base class still
        # rolls the program back and re-raises.
        if exc_type is None:
            self.rnn.complete_rnn_op()
        return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb)
315
+
316
+
317
class StaticRNNMemoryLink(object):
    """
    Records how one StaticRNN memory is wired across time steps.

    :param init: the initial variable for the memory
    :type init: Variable
    :param pre_mem: the memory variable from the previous time step
    :type pre_mem: Variable
    :param mem: the memory variable in the current time step; filled in
        later via ``StaticRNN.update_memory``
    :type mem: Variable
    """

    def __init__(self, init, pre_mem, mem=None):
        # Plain attribute holder; no behavior beyond storing the three links.
        self.mem = mem
        self.pre_mem = pre_mem
        self.init = init
331
+
332
+
333
class StaticRNN(object):
    """
    Builder for a static (fixed sequence length) RNN.

    Usage: construct a StaticRNN, then build one time step inside
    ``with rnn.step():`` using ``step_input``/``memory``/``update_memory``/
    ``step_output``, and finally call the rnn object to retrieve outputs.
    The actual RNN operator creation is deferred to ``complete_rnn_op``
    (still a TODO in this revision).
    """

    # Builder status flags; transitions are driven by StaticRNNGuard.
    BEFORE_RNN_BLOCK = 0
    IN_RNN_BLOCK = 1
    AFTER_RNN_BLOCK = 2

    def __init__(self, name=None, program=None):
        self.helper = LayerHelper("static_rnn", name=name, program=program)
        self.memories = {}  # memory map, from pre_mem.name --> MemoryLink
        self.inputs = []  # input variable list in current block
        self.outputs = []  # output variable list in parent block
        self.status = StaticRNN.BEFORE_RNN_BLOCK  # status flag.
        # sequence length, since it is a static RNN, sequence length are fixed.
        self.seq_len = None

    def step(self):
        # Returns a context manager; entering it opens the step sub-block.
        return StaticRNNGuard(self)

    def _assert_in_rnn_block_(self, method):
        # Guard: the step-building methods are only valid inside `with step()`.
        if self.status != StaticRNN.IN_RNN_BLOCK:
            raise ValueError("You must invoke {0} in rnn block".format(method))

    def memory(self, init=None, shape=None, dtype=None, init_value=0):
        """Declare a recurrent memory.

        Either pass ``init`` (a Variable holding the initial value) or
        ``shape`` + ``dtype``, in which case a boot variable is created in
        the parent block and filled with ``init_value`` via a
        fill_constant op, then this method recurses with that boot var.
        Returns the previous-step memory Variable to use inside the step.
        """
        self._assert_in_rnn_block_('memory')
        if init is None:
            if shape is None or dtype is None:
                raise ValueError(
                    "if init is None, memory at least need shape and dtype")
            parent_block = self.parent_block()
            var_name = unique_name("@".join([self.helper.name, "memory_boot"]))
            boot_var = parent_block.create_var(
                name=var_name, shape=shape, dtype=dtype, persistable=False)

            # Materialize the initial value in the PARENT block so it exists
            # before the rnn op runs.
            parent_block.append_op(
                type="fill_constant",
                inputs={},
                outputs={'Out': [boot_var]},
                attrs={
                    'value': init_value,
                    'shape': boot_var.shape,
                    'data_type': boot_var.data_type
                })

            # Recurse into the `init is not None` branch with the boot var.
            return self.memory(init=boot_var)
        else:
            pre_mem = self.helper.create_variable(
                name=unique_name("@".join([self.helper.name, "mem"])),
                dtype=init.data_type,
                shape=init.shape)
            self.memories[pre_mem.name] = StaticRNNMemoryLink(
                init=init, pre_mem=pre_mem)
            return pre_mem

    def step_input(self, x):
        """Register ``x`` as a sequence input; returns the per-step slice.

        NOTE(review): the code reads the sequence length from ``x.shape[1]``
        and keeps ``x.shape[2:]`` for the per-step variable — presumably the
        layout is (batch, seq_len, ...); confirm against the rnn op.
        """
        self._assert_in_rnn_block_('step_input')
        if not isinstance(x, Variable):
            raise TypeError("step input takes a Variable")
        if self.seq_len is None:
            self.seq_len = x.shape[1]
        elif self.seq_len != x.shape[1]:
            raise ValueError("Static RNN only take fix seq_len input")

        # Per-step view of the input: sequence axis removed, batch left as -1.
        ipt = self.helper.create_variable(
            name=x.name,
            dtype=x.data_type,
            shape=[-1] + list(x.shape[2:]),
            type=x.type)
        self.inputs.append(ipt)
        return ipt

    def step_output(self, o):
        """Register a per-step output; the stacked sequence output variable
        is created in the parent block.

        NOTE(review): uses ``self.seq_len`` for the output shape — if called
        before any ``step_input``, seq_len is still None; verify callers.
        """
        self._assert_in_rnn_block_('step_output')
        if not isinstance(o, Variable):
            raise TypeError("step output takes a Variable")

        out_var = self.parent_block().create_var(
            name=o.name,
            shape=[-1, self.seq_len] + list(o.shape[1:]),
            dtype=o.data_type)

        self.outputs.append(out_var)

    def output(self, *outputs):
        # Convenience wrapper: register several step outputs at once.
        for each in outputs:
            self.step_output(each)

    def update_memory(self, mem, var):
        """Link memory ``mem`` (returned by ``memory``) to its value ``var``
        computed in the current step."""
        if not isinstance(mem, Variable) or not isinstance(var, Variable):
            raise TypeError("update memory should take variables")
        # Raises KeyError if `mem` was not created via self.memory().
        self.memories[mem.name].mem = var

    def parent_block(self):
        # The block enclosing the step sub-block; asserts we are not at root.
        prog = self.helper.program
        parent_idx = prog.current_block().parent_idx
        assert parent_idx >= 0
        parent_block = prog.block(parent_idx)
        return parent_block

    def __call__(self, *args, **kwargs):
        """Return the rnn's output variable(s) after the step block closed.

        Returns a single Variable when there is exactly one output,
        otherwise the list of outputs. ``args``/``kwargs`` are ignored.
        """
        if self.status != StaticRNN.AFTER_RNN_BLOCK:
            raise ValueError("RNN output can only be retrieved after rnn block")
        if len(self.outputs) == 0:
            raise ValueError("RNN has no output")
        elif len(self.outputs) == 1:
            return self.outputs[0]
        else:
            return self.outputs

    def complete_rnn_op(self):
        # TODO(yuyang18): Create RNN Op here.
        # Implement this method after RNN op complete.
        pass
0 commit comments