@@ -375,13 +375,11 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
375
375
dtype = self .dtype ,
376
376
)
377
377
378
- self .tgt_generation_mask = paddle .zeros (
378
+ self .tgt_generation_mask = paddle .ones (
379
379
shape = [config .batch_size , 1 , 1 , config .total_max_length ],
380
380
dtype = self .dtype ,
381
381
)
382
- self .arange_tensor_encoder = paddle .zeros (
383
- shape = (config .batch_size , 1 , config .total_max_length ), dtype = self .dtype
384
- )
382
+ self .arange_tensor_encoder = paddle .arange (config .total_max_length , dtype = self .dtype )
385
383
386
384
if config .export_precache :
387
385
if config .prefix_path :
@@ -427,7 +425,7 @@ def _postprocess(self, predictions):
427
425
428
426
def _preprocess (self , source ):
429
427
self .attention_mask [:] = 0
430
- self .tgt_generation_mask [:] = 0
428
+ self .tgt_generation_mask [:] = 1
431
429
pre_caches_length = 0 if not self .config .export_precache else self .pre_caches [0 ].shape [- 2 ]
432
430
433
431
if self .tokenizer .chat_template is not None :
@@ -468,15 +466,6 @@ def _preprocess(self, source):
468
466
[prefix_attention_mask , post_attention_mask ], axis = 2
469
467
)
470
468
471
- if self .config .prefix_path is None :
472
- self .tgt_generation_mask [i , 0 , 0 , pre_caches_length : length + pre_caches_length ] = paddle .ones (
473
- shape = [1 , length ], dtype = self .config .dtype
474
- )
475
- else :
476
- self .tgt_generation_mask [i , 0 , 0 , : length + pre_caches_length ] = paddle .ones (
477
- shape = [1 , length + pre_caches_length ], dtype = self .config .dtype
478
- )
479
-
480
469
inputs ["tgt_pos" ] = self .tgt_pos
481
470
elif "bloom" in self .architectures :
482
471
for i in range (inputs ["input_ids" ].shape [0 ]):
@@ -496,20 +485,13 @@ def _preprocess(self, source):
496
485
self .attention_mask [i , :, :length , : length + pre_caches_length ] = paddle .concat (
497
486
[prefix_attention_mask , post_attention_mask ], axis = 2
498
487
)
499
- self .arange_tensor_encoder [i , :, : length + pre_caches_length ] = paddle .arange (
500
- length + pre_caches_length
501
- ).astype (self .config .dtype )
502
488
503
- self .tgt_generation_mask [i , :, 0 , : length + pre_caches_length ] = paddle .ones (
504
- shape = [1 , length + pre_caches_length ], dtype = self .config .dtype
505
- )
506
489
inputs ["tgt_pos" ] = inputs ["tgt_pos" ] + pre_caches_length
507
490
# alibi encoder
508
491
alibi_slopes = get_alibi_slopes (self .model_config .n_head )
509
492
inputs ["position_ids" ] = paddle .to_tensor (alibi_slopes , dtype = "float32" )
510
493
511
- alibi = alibi_slopes [..., None ] * self .arange_tensor_encoder
512
- alibi = alibi [:, :, None , :]
494
+ alibi = alibi_slopes [None , :, None , None ] * self .arange_tensor_encoder
513
495
514
496
if self .model_config .tensor_parallel_degree > 1 :
515
497
block_size = self .model_config .n_head // self .model_config .tensor_parallel_degree
@@ -534,6 +516,9 @@ def _preprocess(self, source):
534
516
self .config .total_max_length ,
535
517
]
536
518
)
519
+ # Only the valid encoder region keeps its alibi values; entries from index `length` onward on both of the last two axes are set to 0.
520
+ alibi_encoder [i , :, length :, length :] = 0
521
+
537
522
alibi_decoder = alibi .expand (
538
523
[
539
524
self .config .batch_size ,
@@ -572,15 +557,6 @@ def _preprocess(self, source):
572
557
[prefix_attention_mask , post_attention_mask ], axis = 2
573
558
)
574
559
575
- if self .config .prefix_path is None :
576
- self .tgt_generation_mask [i , 0 , 0 , pre_caches_length : length + pre_caches_length ] = paddle .ones (
577
- shape = [1 , length ], dtype = "float16"
578
- )
579
- else :
580
- self .tgt_generation_mask [i , 0 , 0 , : length + pre_caches_length ] = paddle .ones (
581
- shape = [1 , length + pre_caches_length ], dtype = self .config .dtype
582
- )
583
-
584
560
inputs ["pre_ids" ] = self .pre_ids
585
561
inputs ["attention_mask" ] = self .attention_mask
586
562
inputs ["tgt_generation_mask" ] = self .tgt_generation_mask
0 commit comments