Skip to content

Commit cd4dd0a

Browse files
committed
Fix torch end-to-end generation flow on GPU (#1122)
We don't have a good way to test this yet until we add GPU testing, so that will have to come later.
1 parent 684a2eb commit cd4dd0a

File tree

4 files changed

+20
-0
lines changed

4 files changed

+20
-0
lines changed

keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from absl import logging
2121

2222
from keras_nlp.api_export import keras_nlp_export
23+
from keras_nlp.backend import ops
2324
from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor
2425
from keras_nlp.models.bart.bart_presets import backbone_presets
2526
from keras_nlp.utils.keras_utils import (
@@ -267,6 +268,10 @@ def generate_postprocess(
267268
x["decoder_token_ids"],
268269
x["decoder_padding_mask"],
269270
)
271+
if not isinstance(decoder_token_ids, tf.Tensor):
272+
decoder_token_ids = ops.convert_to_numpy(decoder_token_ids)
273+
if not isinstance(decoder_padding_mask, tf.Tensor):
274+
decoder_padding_mask = ops.convert_to_numpy(decoder_padding_mask)
270275
# Strip any special tokens during detokenization, i.e., the start and
271276
# end markers. In the future, we could make this configurable.
272277
decoder_padding_mask = (

keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from absl import logging
1919

2020
from keras_nlp.api_export import keras_nlp_export
21+
from keras_nlp.backend import ops
2122
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
2223
from keras_nlp.utils.keras_utils import (
2324
convert_inputs_to_list_of_tensor_segments,
@@ -164,6 +165,10 @@ def generate_postprocess(
164165
back to a string.
165166
"""
166167
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
168+
if not isinstance(token_ids, tf.Tensor):
169+
token_ids = ops.convert_to_numpy(token_ids)
170+
if not isinstance(padding_mask, tf.Tensor):
171+
padding_mask = ops.convert_to_numpy(padding_mask)
167172
# Strip any special tokens during detokenization (e.g. the start and
168173
# end markers). In the future we could make this configurable.
169174
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)

keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from absl import logging
1919

2020
from keras_nlp.api_export import keras_nlp_export
21+
from keras_nlp.backend import ops
2122
from keras_nlp.models.gpt_neo_x.gpt_neo_x_preprocessor import (
2223
GPTNeoXPreprocessor,
2324
)
@@ -132,6 +133,10 @@ def generate_postprocess(
132133
back to a string.
133134
"""
134135
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
136+
if not isinstance(token_ids, tf.Tensor):
137+
token_ids = ops.convert_to_numpy(token_ids)
138+
if not isinstance(padding_mask, tf.Tensor):
139+
padding_mask = ops.convert_to_numpy(padding_mask)
135140
# Strip any special tokens during detokenization (e.g. the start and
136141
# end markers). In the future we could make this configurable.
137142
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)

keras_nlp/models/opt/opt_causal_lm_preprocessor.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from absl import logging
1919

2020
from keras_nlp.api_export import keras_nlp_export
21+
from keras_nlp.backend import ops
2122
from keras_nlp.models.opt.opt_preprocessor import OPTPreprocessor
2223
from keras_nlp.utils.keras_utils import (
2324
convert_inputs_to_list_of_tensor_segments,
@@ -165,6 +166,10 @@ def generate_postprocess(
165166
back to a string.
166167
"""
167168
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
169+
if not isinstance(token_ids, tf.Tensor):
170+
token_ids = ops.convert_to_numpy(token_ids)
171+
if not isinstance(padding_mask, tf.Tensor):
172+
padding_mask = ops.convert_to_numpy(padding_mask)
168173
# Strip any special tokens during detokenization (e.g. the start and
169174
# end markers). In the future we could make this configurable.
170175
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)

0 commit comments

Comments (0)