File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed
paddlenlp/ops/faster_transformer/transformer Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -1103,7 +1103,8 @@ def forward(self,
1103
1103
self .encoder = enable_faster_encoder (self .encoder , need_build = False )
1104
1104
if encoder_output is None :
1105
1105
assert input_ids is not None , "You have to specify either input_ids or encoder_output."
1106
- encoder_output = self .encoder (input_ids )
1106
+ encoder_output = self .prepare_encoder_decoder_kwargs_for_generation (
1107
+ input_ids , model_kwargs )["encoder_output" ]
1107
1108
self .encoder = disable_faster_encoder (self .encoder )
1108
1109
if seq_len is None :
1109
1110
assert input_ids is not None , "You have to specify either input_ids when generating seq_len."
@@ -1206,7 +1207,8 @@ def forward(self,
1206
1207
self .encoder = enable_faster_encoder (self .encoder , need_build = False )
1207
1208
if encoder_output is None :
1208
1209
assert input_ids is not None , "You have to specify either input_ids or encoder_output."
1209
- encoder_output = self .encoder (input_ids )
1210
+ encoder_output = self .prepare_encoder_decoder_kwargs_for_generation (
1211
+ input_ids , model_kwargs )["encoder_output" ]
1210
1212
self .encoder = disable_faster_encoder (self .encoder )
1211
1213
batch_size = paddle .shape (encoder_output )[0 ]
1212
1214
if seq_len is None :
You can’t perform that action at this time.
0 commit comments