@@ -364,6 +364,8 @@ class TransformerDecoderLayer(nn.Layer):
     It contains multiheadattention and some linear layers.
     """

+    Cache = collections.namedtuple("Cache", ["kv"])
+
     def __init__(self,
                  d_model,
                  nhead,
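
The new Cache namedtuple above holds one stacked key/value tensor per layer, while the existing MultiHeadAttention.Cache keeps k and v separately. A minimal sketch of the relationship (illustration only, not part of the patch; shapes are made up):

    # Illustration only -- not part of the patch. Shapes are made up.
    import collections

    import paddle

    FusedCache = collections.namedtuple("Cache", ["kv"])         # new fused layout
    SeparateCache = collections.namedtuple("Cache", ["k", "v"])  # MultiHeadAttention layout

    batch, heads, seq, head_size = 2, 16, 8, 64
    k = paddle.zeros([batch, heads, seq, head_size])
    v = paddle.zeros([batch, heads, seq, head_size])

    separate = SeparateCache(k, v)
    fused = FusedCache(paddle.stack([k, v], axis=0))  # [2, batch, heads, seq, head_size]
    assert fused.kv.shape == [2, batch, heads, seq, head_size]
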
@@ -375,7 +377,8 @@ def __init__(self,
                 normalize_before=True,
                 weight_attr=None,
                 bias_attr=None,
-                 topo=None):
+                 topo=None,
+                 **kwargs):
        self._config = locals()
        self._config.pop("self")
        self._config.pop("__class__", None)  # py3
@@ -388,45 +391,94 @@ def __init__(self,
        weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
        bias_attrs = _convert_param_attr_to_list(bias_attr, 3)

-        self.self_attn = MultiHeadAttention(
-            d_model,
-            nhead,
-            dropout=attn_dropout,
-            weight_attr=weight_attrs[0],
-            bias_attr=bias_attrs[0],
-            topo=topo)
-        if topo is None or topo.mp_info.size == 1:
-            self.linear1 = nn.Linear(
+        self._fuse = kwargs.get('fuse', False)
+        if self._fuse:
+            nranks, ring_id = 1, -1
+            if topo is not None and topo.mp_info.size > 1:
+                nranks = topo.mp_info.size
+                ring_id = 0
+            self.self_attn = incubate.nn.FusedMultiHeadAttention(
                d_model,
-                dim_feedforward,
-                weight_attrs[2],
-                bias_attr=bias_attrs[2])
-            self.linear2 = nn.Linear(
-                dim_feedforward,
+                nhead,
+                dropout_rate=dropout,
+                attn_dropout_rate=attn_dropout,
+                normalize_before=normalize_before,
+                qkv_weight_attr=weight_attrs[0],
+                qkv_bias_attr=bias_attrs[0],
+                linear_weight_attr=weight_attrs[0],
+                linear_bias_attr=bias_attrs[0],
+                epsilon=1e-5,
+                nranks=nranks,
+                ring_id=ring_id)
+            self.ffn = incubate.nn.FusedFeedForward(
                d_model,
-                weight_attrs[2],
-                bias_attr=bias_attrs[2])
+                dim_feedforward,
+                dropout_rate=act_dropout,
+                epsilon=1e-5,
+                activation=activation,
+                normalize_before=normalize_before,
+                act_dropout_rate=0.0,
+                linear1_weight_attr=weight_attrs[2],
+                linear1_bias_attr=bias_attrs[2],
+                linear2_weight_attr=weight_attrs[2],
+                linear2_bias_attr=bias_attrs[2],
+                nranks=nranks,
+                ring_id=ring_id)
        else:
-            self.linear1 = paddlenlp.ops.ColumnParallelLiner(
-                (d_model, dim_feedforward),
-                topo.mp_info.size,
-                gather_out=False,
-                param_attr=weight_attrs[2],
-                bias_attr=bias_attrs[2])
-            self.linear2 = paddlenlp.ops.RowParallelLiner(
-                (dim_feedforward, d_model),
-                topo.mp_info.size,
-                input_is_parallel=True,
-                param_attr=weight_attrs[2],
-                bias_attr=bias_attrs[2])
+            self.self_attn = MultiHeadAttention(
+                d_model,
+                nhead,
+                dropout=attn_dropout,
+                weight_attr=weight_attrs[0],
+                bias_attr=bias_attrs[0],
+                topo=topo)
+            if topo is None or topo.mp_info.size == 1:
+                self.linear1 = nn.Linear(
+                    d_model,
+                    dim_feedforward,
+                    weight_attrs[2],
+                    bias_attr=bias_attrs[2])
+                self.linear2 = nn.Linear(
+                    dim_feedforward,
+                    d_model,
+                    weight_attrs[2],
+                    bias_attr=bias_attrs[2])
+            else:
+                self.linear1 = paddlenlp.ops.ColumnParallelLiner(
+                    (d_model, dim_feedforward),
+                    topo.mp_info.size,
+                    gather_out=False,
+                    param_attr=weight_attrs[2],
+                    bias_attr=bias_attrs[2])
+                self.linear2 = paddlenlp.ops.RowParallelLiner(
+                    (dim_feedforward, d_model),
+                    topo.mp_info.size,
+                    input_is_parallel=True,
+                    param_attr=weight_attrs[2],
+                    bias_attr=bias_attrs[2])

-        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
-        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
-        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
-        self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
-        self.activation = getattr(F, activation)
+            self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
+            self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
+            self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
+            self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
+            self.activation = getattr(F, activation)

    def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
+        if self._fuse:
+            if isinstance(cache, self.Cache):
+                attn_output, cache_kv_out = self.self_attn(
+                    tgt, attn_mask=tgt_mask, cache=cache.kv)
+
+                # if not assigned here, the caches are updated in the While loop
+                # layers.assign(cache_kv_out, cache.kv)
+                if use_cache:
+                    cache = self.Cache(cache_kv_out)
+            else:
+                attn_output = self.self_attn(tgt, attn_mask=tgt_mask)
+
+            enc_out = self.ffn(attn_output)
+            return (enc_out, cache) if use_cache else enc_out
+
        residual = tgt

        if self.normalize_before:
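
When fuse is enabled, the layer above swaps MultiHeadAttention and the Linear/ColumnParallel stack for paddle.incubate.nn.FusedMultiHeadAttention and FusedFeedForward, which carry their own layer norm and dropout. Below is a minimal standalone sketch of that fused pair, assuming a CUDA build of PaddlePaddle (the fused kernels are GPU-only); sizes and the mask are illustrative:

    # Standalone sketch of the fused pair used above; assumes a CUDA build of
    # PaddlePaddle, since the fused kernels only run on GPU. Sizes are illustrative.
    import paddle

    d_model, nhead, dim_feedforward = 768, 12, 3072
    attn = paddle.incubate.nn.FusedMultiHeadAttention(
        d_model, nhead, dropout_rate=0.1, attn_dropout_rate=0.1,
        normalize_before=True, epsilon=1e-5)
    ffn = paddle.incubate.nn.FusedFeedForward(
        d_model, dim_feedforward, dropout_rate=0.1, activation="gelu",
        normalize_before=True, epsilon=1e-5)

    x = paddle.randn([2, 8, d_model])        # [batch, seq_len, hidden]
    mask = paddle.zeros([2, nhead, 8, 8])    # additive attention mask (0 = keep)
    attn_out = attn(x, attn_mask=mask)       # without a cache, returns only the output
    out = ffn(attn_out)
    print(out.shape)                         # same as x: [2, 8, 768]
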
@@ -687,7 +739,8 @@ def __init__(self,
                 eos_token_id=7,
                 bos_token_id=0,
                 eol_token_id=3,
-                 topo=None):
+                 topo=None,
+                 **kwargs):
        super(GPTModel, self).__init__()

        self.pad_token_id = pad_token_id
@@ -727,7 +780,8 @@ def __init__(self,
                        initializer=nn.initializer.Normal(
                            mean=0.0, std=self.initializer_range)),
                    bias_attr=None,
-                    topo=topo))
+                    topo=topo,
+                    fuse=kwargs.get('fuse', False)))

        if self.pipline_mode:
            Decoder = paddlenlp.ops.guard('gpu:{}'.format(
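
The fuse flag is not a named parameter of GPTModel; it rides in through **kwargs and is forwarded to every TransformerDecoderLayer when the decoder stack is built. A usage sketch, with the other constructor arguments chosen only for illustration:

    # Usage sketch only; argument values are illustrative.
    model = GPTModel(
        vocab_size=50304,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        topo=None,   # single-card run, no parallel topology
        fuse=True)   # consumed from **kwargs and passed down to each decoder layer
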
@@ -866,7 +920,8 @@ def __init__(self,
                 temperature=1.0,
                 top_k=0,
                 top_p=1.0,
-                 eos_id=None):
+                 eos_id=None,
+                 **kwargs):
        super(GPTForGeneration, self).__init__()
        self.gpt = gpt
        self.apply(self.init_weights)
@@ -879,32 +934,43 @@ def __init__(self,
        self.temperature = temperature
        self.topk = top_k
        self.topp = top_p
-        self._fuse = False
        self._init_gen_cache = False
-        self.generation_caches = []
+        self.generation_caches = None
        self._dtype = "float32"
+        self._fuse = kwargs.get("fuse", False)

    def _init_generation_caches(self, src_ids):
-        if self._init_gen_cache:
+        # not fused: no global caches to build, return None
+        if self._init_gen_cache or self._fuse is False:
            return self.generation_caches

+        self.generation_caches = []
        num_heads = self.gpt.num_attention_heads
        num_layers = self.gpt.num_hidden_layers
        mp_n_head = num_heads // self.gpt.topo.mp_info.size
        hidden_size = self.gpt.hidden_size
        head_size = hidden_size // num_heads
        for i in range(num_layers):
-            k = layers.fill_constant_batch_size_like(
-                input=src_ids,
-                shape=[-1, mp_n_head, 0, head_size],
-                dtype=self._dtype,
-                value=0)
-            v = layers.fill_constant_batch_size_like(
-                input=src_ids,
-                shape=[-1, mp_n_head, 0, head_size],
-                dtype=self._dtype,
-                value=0)
-            self.generation_caches.append(MultiHeadAttention.Cache(k, v))
+            if self._fuse:
+                kv = layers.fill_constant_batch_size_like(
+                    input=src_ids,
+                    shape=[2, -1, mp_n_head, 0, head_size],
+                    dtype=self._dtype,
+                    value=0,
+                    output_dim_idx=1)
+                self.generation_caches.append(TransformerDecoderLayer.Cache(kv))
+            else:
+                k = layers.fill_constant_batch_size_like(
+                    input=src_ids,
+                    shape=[-1, mp_n_head, 0, head_size],
+                    dtype=self._dtype,
+                    value=0)
+                v = layers.fill_constant_batch_size_like(
+                    input=src_ids,
+                    shape=[-1, mp_n_head, 0, head_size],
+                    dtype=self._dtype,
+                    value=0)
+                self.generation_caches.append(MultiHeadAttention.Cache(k, v))
        self._init_gen_cache = True
        return self.generation_caches

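
For the fused path, the per-layer cache is a single tensor with K and V stacked on dim 0, the batch size copied from src_ids into dim 1 (hence output_dim_idx=1), and a length-0 time axis that grows during decoding. A static-graph sketch of that call, assuming a Paddle version where paddle.fluid.layers is still importable (as this file itself assumes); sizes are illustrative:

    # Static-graph sketch of the fused cache init above; assumes paddle.fluid.layers
    # is available, as in the modeling file itself. Sizes are illustrative.
    import paddle
    import paddle.fluid.layers as layers

    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        src_ids = paddle.static.data(name="src_ids", shape=[-1, 128], dtype="int64")
        kv = layers.fill_constant_batch_size_like(
            input=src_ids,
            shape=[2, -1, 16, 0, 64],   # [2, batch, mp_n_head, seq(=0), head_size]
            dtype="float32",
            value=0,
            output_dim_idx=1)           # batch size from src_ids lands in dim 1
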
@@ -1011,10 +1077,14 @@ def forward(self, inputs, use_cache=False, cache=None):

        # if cached_kvs are assigned to next step in _prepare_qkv of MultiHeadAttention,
        # need to init the global caches here
-        # gen_caches = self._init_generation_caches(input_ids)
+        gen_caches = self._init_generation_caches(input_ids)

        logits, cached_kvs = self.model(
-            input_ids, position_ids, encode_mask, use_cache=True)
+            input_ids,
+            position_ids,
+            encode_mask,
+            use_cache=True,
+            cache=gen_caches)

        next_id = paddle.argmax(logits[:, -1, :], axis=-1).reshape([-1, 1])
        ####################################
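
The first decoding step above picks the next token greedily: take the logits at the last position and argmax over the vocabulary. A self-contained sketch with random stand-in logits:

    # Self-contained sketch of the greedy step; logits are random stand-ins.
    import paddle

    batch_size, seq_len, vocab_size = 2, 5, 50304
    logits = paddle.randn([batch_size, seq_len, vocab_size])

    # distribution at the last position, argmax per sample, kept as a column
    next_id = paddle.argmax(logits[:, -1, :], axis=-1).reshape([-1, 1])
    print(next_id.shape)  # [2, 1]
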
@@ -1092,7 +1162,10 @@ def forward(self, inputs, use_cache=False, cache=None):
            paddle.assign(layers.cast(cond, dtype='bool'), cond)
            if attention_mask:
                paddle.assign(decode_mask, attention_mask)
-            for i in range(len(decode_cached_kvs)):
+            for i in range(len(decode_cached_kvs)):
+                if self._fuse:
+                    paddle.assign(decode_cached_kvs[i].kv, cached_kvs[i].kv)
+                else:
                    paddle.assign(decode_cached_kvs[i].k, cached_kvs[i].k)
                    paddle.assign(decode_cached_kvs[i].v, cached_kvs[i].v)

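
The loop above writes each step's caches back into the persistent ones in place, so the update stays visible across iterations of the static While loop; with fuse on there is a single kv tensor per layer instead of separate k and v. A minimal sketch of the in-place paddle.assign pattern, with made-up shapes:

    # Minimal sketch of the in-place write-back; shapes are made up.
    import paddle

    persistent_kv = paddle.zeros([2, 4, 16, 1, 64])  # cache carried across steps
    step_kv = paddle.ones([2, 4, 16, 1, 64])         # cache produced by this step

    paddle.assign(step_kv, persistent_kv)            # copy into the existing variable
    print(float(persistent_kv.sum()))                # 8192.0 -> the cache was updated
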