
Commit e468e19 (1 parent: 10ac335)

add attn_mask input for encoder-decoder (#1431)

File tree: 1 file changed, +4 −2 lines

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 4 additions & 2 deletions
@@ -1103,7 +1103,8 @@ def forward(self,
         self.encoder = enable_faster_encoder(self.encoder, need_build=False)
         if encoder_output is None:
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
-            encoder_output = self.encoder(input_ids)
+            encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
+                input_ids, model_kwargs)["encoder_output"]
         self.encoder = disable_faster_encoder(self.encoder)
         if seq_len is None:
             assert input_ids is not None, "You have to specify either input_ids when generating seq_len."

@@ -1206,7 +1207,8 @@ def forward(self,
         self.encoder = enable_faster_encoder(self.encoder, need_build=False)
         if encoder_output is None:
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
-            encoder_output = self.encoder(input_ids)
+            encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
+                input_ids, model_kwargs)["encoder_output"]
         self.encoder = disable_faster_encoder(self.encoder)
         batch_size = paddle.shape(encoder_output)[0]
         if seq_len is None:
