Commit 57504e7

Remove some internal changes related to tgt_attention_mask (#7696)

* commit
* remove some things

1 parent 228bd14 commit 57504e7

5 files changed: +18 additions, -322 deletions


csrc/generation/set_alibi_mask_value.cu

Lines changed: 0 additions & 136 deletions
This file was deleted.

csrc/generation/set_mask_value.cu

Lines changed: 0 additions & 123 deletions
This file was deleted.

csrc/setup_cuda.py

Lines changed: 0 additions & 2 deletions
@@ -55,7 +55,6 @@ def get_gencode_flags():
     ext_modules=CUDAExtension(
         sources=[
             "./generation/save_with_output.cc",
-            "./generation/set_mask_value.cu",
             "./generation/set_value_by_flags.cu",
             "./generation/token_penalty_multi_scores.cu",
             "./generation/stop_generation_multi_ends.cu",
@@ -66,7 +65,6 @@ def get_gencode_flags():
             "./generation/transpose_removing_padding.cu",
             "./generation/write_cache_kv.cu",
             "./generation/encode_rotary_qk.cu",
-            "./generation/set_alibi_mask_value.cu",
             "./generation/quant_int8.cu",
             "./generation/dequant_int8.cu",
         ],
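
For context, a minimal sketch of how a `csrc/setup_cuda.py` built on `paddle.utils.cpp_extension` looks once the two deleted kernels are dropped from the source list; the package name and the trimmed-down list here are illustrative, not copied from the repository.

```python
# Hypothetical, simplified build script: only shows that the two removed
# kernels no longer appear in the CUDAExtension source list.
from paddle.utils.cpp_extension import CUDAExtension, setup

setup(
    name="paddlenlp_ops",  # assumed package name, for illustration only
    ext_modules=CUDAExtension(
        sources=[
            "./generation/save_with_output.cc",
            "./generation/set_value_by_flags.cu",
            # ... the remaining generation kernels kept by this commit ...
            # set_mask_value.cu and set_alibi_mask_value.cu are gone.
        ],
    ),
)
```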

llm/predictor.py

Lines changed: 7 additions & 31 deletions
@@ -375,13 +375,11 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
             dtype=self.dtype,
         )
 
-        self.tgt_generation_mask = paddle.zeros(
+        self.tgt_generation_mask = paddle.ones(
             shape=[config.batch_size, 1, 1, config.total_max_length],
             dtype=self.dtype,
         )
-        self.arange_tensor_encoder = paddle.zeros(
-            shape=(config.batch_size, 1, config.total_max_length), dtype=self.dtype
-        )
+        self.arange_tensor_encoder = paddle.arange(config.total_max_length, dtype=self.dtype)
 
         if config.export_precache:
             if config.prefix_path:
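
The net effect of this hunk is easiest to see with toy sizes: the decode-step mask now starts out as all ones, and the position ramp becomes a single shared 1-D arange instead of a per-batch zero buffer. The snippet below is a standalone sketch with made-up `batch_size` and `total_max_length` and float32 in place of the configured dtype.

```python
# Standalone sketch (invented sizes, float32 instead of the config dtype) of
# the buffers created by the new __init__ code.
import paddle

batch_size, total_max_length = 2, 8

# Decode-step mask now starts as all ones for every request.
tgt_generation_mask = paddle.ones(shape=[batch_size, 1, 1, total_max_length], dtype="float32")
# Position ramp is a single shared 1-D vector [0, 1, ..., total_max_length - 1].
arange_tensor_encoder = paddle.arange(total_max_length, dtype="float32")

print(tgt_generation_mask.shape)   # [2, 1, 1, 8]
print(arange_tensor_encoder.shape) # [8]
```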
@@ -427,7 +425,7 @@ def _postprocess(self, predictions):
 
     def _preprocess(self, source):
         self.attention_mask[:] = 0
-        self.tgt_generation_mask[:] = 0
+        self.tgt_generation_mask[:] = 1
         pre_caches_length = 0 if not self.config.export_precache else self.pre_caches[0].shape[-2]
 
         if self.tokenizer.chat_template is not None:
@@ -468,15 +466,6 @@ def _preprocess(self, source):
                     [prefix_attention_mask, post_attention_mask], axis=2
                 )
 
-                if self.config.prefix_path is None:
-                    self.tgt_generation_mask[i, 0, 0, pre_caches_length : length + pre_caches_length] = paddle.ones(
-                        shape=[1, length], dtype=self.config.dtype
-                    )
-                else:
-                    self.tgt_generation_mask[i, 0, 0, : length + pre_caches_length] = paddle.ones(
-                        shape=[1, length + pre_caches_length], dtype=self.config.dtype
-                    )
-
             inputs["tgt_pos"] = self.tgt_pos
         elif "bloom" in self.architectures:
             for i in range(inputs["input_ids"].shape[0]):
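
A hedged before/after illustration of why the per-sample fill can be removed: under the old initialization the mask was zeroed and then re-marked for each request's prompt (and pre-cache) span, whereas the new all-ones mask needs no per-request update. Sizes are invented for the example.

```python
# Toy before/after comparison (not repo code) of the decode-step mask handling.
import paddle

batch_size, total_max_length = 2, 8
length, pre_caches_length = 3, 0
i = 0

# Old behaviour: start from zeros and mark only the prompt span per sample.
old_mask = paddle.zeros([batch_size, 1, 1, total_max_length], dtype="float32")
old_mask[i, 0, 0, pre_caches_length : length + pre_caches_length] = 1  # the removed fill

# New behaviour: the mask is simply all ones, no per-sample work needed.
new_mask = paddle.ones([batch_size, 1, 1, total_max_length], dtype="float32")

print(old_mask[i, 0, 0].numpy())  # [1. 1. 1. 0. 0. 0. 0. 0.]
print(new_mask[i, 0, 0].numpy())  # [1. 1. 1. 1. 1. 1. 1. 1.]
```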
@@ -496,20 +485,13 @@ def _preprocess(self, source):
                     self.attention_mask[i, :, :length, : length + pre_caches_length] = paddle.concat(
                         [prefix_attention_mask, post_attention_mask], axis=2
                     )
-                    self.arange_tensor_encoder[i, :, : length + pre_caches_length] = paddle.arange(
-                        length + pre_caches_length
-                    ).astype(self.config.dtype)
 
-                    self.tgt_generation_mask[i, :, 0, : length + pre_caches_length] = paddle.ones(
-                        shape=[1, length + pre_caches_length], dtype=self.config.dtype
-                    )
                     inputs["tgt_pos"] = inputs["tgt_pos"] + pre_caches_length
             # alibi encoder
             alibi_slopes = get_alibi_slopes(self.model_config.n_head)
             inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32")
 
-            alibi = alibi_slopes[..., None] * self.arange_tensor_encoder
-            alibi = alibi[:, :, None, :]
+            alibi = alibi_slopes[None, :, None, None] * self.arange_tensor_encoder
 
             if self.model_config.tensor_parallel_degree > 1:
                 block_size = self.model_config.n_head // self.model_config.tensor_parallel_degree
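
The rewritten ALiBi line relies on broadcasting: slopes of shape `[n_head]` times the shared 1-D ramp of shape `[total_max_length]` yield a `[1, n_head, 1, total_max_length]` bias that is later `expand()`-ed per batch. The sketch below assumes toy slope values and uses `reshape()` in place of the diff's `[None, :, None, None]` indexing.

```python
# Standalone broadcasting sketch with assumed slope values and sizes.
import paddle

n_head, total_max_length = 4, 8
alibi_slopes = paddle.to_tensor([2.0 ** -(h + 1) for h in range(n_head)])   # shape [n_head]
arange_tensor_encoder = paddle.arange(total_max_length, dtype="float32")    # shape [total_max_length]

# [1, n_head, 1, 1] * [total_max_length] broadcasts to [1, n_head, 1, total_max_length].
alibi = alibi_slopes.reshape([1, n_head, 1, 1]) * arange_tensor_encoder
print(alibi.shape)  # [1, 4, 1, 8]
```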
@@ -534,6 +516,9 @@ def _preprocess(self, source):
                         self.config.total_max_length,
                     ]
                 )
+                # only generate valid encoder attention mask, other place set 0.
+                alibi_encoder[i, :, length:, length:] = 0
+
                 alibi_decoder = alibi.expand(
                     [
                         self.config.batch_size,
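
The three added lines keep only the valid encoder block of the expanded bias and zero everything past the prompt length. Below is a toy reproduction with invented sizes; `clone()` is added here only so the expanded view can be written to safely in a standalone snippet.

```python
# Toy reproduction (invented sizes) of the added masking step on the encoder bias.
import paddle

batch_size, n_head, max_len, length = 1, 2, 6, 3
alibi = paddle.arange(max_len, dtype="float32").reshape([1, 1, 1, max_len])

# Expand the bias to [batch_size, n_head, max_len, max_len]; clone() makes it writable here.
alibi_encoder = alibi.expand([batch_size, n_head, max_len, max_len]).clone()

i = 0
alibi_encoder[i, :, length:, length:] = 0  # the assignment introduced by this commit

print(alibi_encoder[0, 0].numpy())  # rows/cols beyond the prompt length are now 0
```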
@@ -572,15 +557,6 @@ def _preprocess(self, source):
                     [prefix_attention_mask, post_attention_mask], axis=2
                 )
 
-                if self.config.prefix_path is None:
-                    self.tgt_generation_mask[i, 0, 0, pre_caches_length : length + pre_caches_length] = paddle.ones(
-                        shape=[1, length], dtype="float16"
-                    )
-                else:
-                    self.tgt_generation_mask[i, 0, 0, : length + pre_caches_length] = paddle.ones(
-                        shape=[1, length + pre_caches_length], dtype=self.config.dtype
-                    )
-
             inputs["pre_ids"] = self.pre_ids
         inputs["attention_mask"] = self.attention_mask
         inputs["tgt_generation_mask"] = self.tgt_generation_mask
