
Commit 4a9c758

Standalone functions for generate pre/post processing for GPT-2 (#998)
* Standalone functions for generate pre/post processing

  This decomposes `generate()` in the way we discussed last week, with the goal of leaving the top-level functionality untouched, but allowing a more granular way to access the preprocessing, postprocessing, and inner dense generation function. Colab [HERE](https://colab.research.google.com/gist/mattdangerw/bb1ef01c1b67255def4a6ad9429de2df/split-preprocessing-demo.ipynb)

  Other than moving things around in the refactor, there is one major change we need to make here: the inner, compiled generate function must also return a padding mask marking which token ids were updated. Without this padding mask, the postprocessor would not know where to truncate output before detokenization. To accommodate this, I made the `generate_function` inputs and outputs a dict with keys "token_ids" and "padding_mask". I actually find this fairly intuitive; with this change, `generate_function` has the same inputs and outputs as directly calling the model!

  ```python
  generate_function = causal_lm.make_generate_function()
  generate_function({
      "token_ids": [[1, 2, 3, 4, 0, 0, 0, 0]],
      "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
  })
  >>> {
      "token_ids": [[1, 2, 3, 4, 5, 6, 7, 8]],
      "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1]],
  }
  generate_function({
      "token_ids": [[1, 2, 3, 4, 0, 0, 0, 0]],
      "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
  }, end_token_id=6)
  >>> {
      "token_ids": [[1, 2, 3, 4, 5, 6, 0, 0]],
      "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
  }
  ```

* More docstring updates

* Fix merge conflict
1 parent 7cbcbed commit 4a9c758
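The snippet below is not part of the commit; it is a minimal sketch of how the three decomposed stages can be driven by hand after this change, assuming the `gpt2_base_en` preset and the KerasNLP API as of this commit.

```python
import keras_nlp

# Assumption: the "gpt2_base_en" preset; any GPT-2 preset should work the same way.
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
preprocessor = causal_lm.preprocessor

# 1. Preprocess: strings -> dense {"token_ids", "padding_mask"} tensors.
inputs = preprocessor.generate_preprocess(
    ["a quick brown fox"], sequence_length=64
)

# 2. Generate: run the compiled inner function on the dense dict.
generate_function = causal_lm.make_generate_function()
outputs = generate_function(
    inputs, end_token_id=preprocessor.tokenizer.end_token_id
)

# 3. Postprocess: truncate at the end token, strip special tokens, detokenize.
print(preprocessor.generate_postprocess(outputs))
```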

File tree

6 files changed: +211 −132 lines changed


keras_nlp/models/gpt2/gpt2_causal_lm.py

Lines changed: 112 additions & 68 deletions
```diff
@@ -29,7 +29,6 @@
 from keras_nlp.utils.keras_utils import is_xla_compatible
 from keras_nlp.utils.python_utils import classproperty
 from keras_nlp.utils.tf_utils import tensor_to_string_list
-from keras_nlp.utils.tf_utils import truncate_at_token
 
 
 @keras_nlp_export("keras_nlp.models.GPT2CausalLM")
@@ -49,7 +48,7 @@ class GPT2CausalLM(Task):
     default, `"top_k"` sampling will be used.
 
     This model can optionally be configured with a `preprocessor` layer, in
-    which case it will automatically apply preprocessing to raw inputs during
+    which case it will automatically apply preprocessing to string inputs during
     `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
     when creating the model with `from_preset()`.
 
@@ -306,28 +305,23 @@ def make_generate_function(self):
 
     def generate_step(
         self,
-        token_ids,
-        padding_mask,
+        inputs,
         end_token_id=None,
     ):
         """A compilable generation function for a single batch of inputs.
 
         This function represents the inner, XLA-compilable, generation function
-        for a single batch of inputs. It takes in a dense `tf.Tensor` of token
-        ids, and return a dense `tf.Tensor` of token ids, and includes no
-        preprocessing. This function is wrapped by the `generate()` method.
+        for a single batch of inputs. Inputs should have the same structure as
+        model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
 
         Args:
-            token_ids: A dense int Tensor, with shape
-                `(batch_size, max_length)`. The user provided token ids
-                padded to `max_length`.
-            padding_mask: A dense boolean Tensor, with the same shape as
-                `token_ids`. Positions that are True in the `padding_mask`
-                are assumed to be user input and never updated.
+            inputs: A dictionary with two keys `"token_ids"` and
+                `"padding_mask"` and batched tensor values.
             end_token_id: The id of the end token to stop on. If all
                 sequences have produced a new `end_token_id`, generation
                 will stop.
         """
+        token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
         # Create and seed cache with a single forward pass.
         hidden_states, cache = self._build_cache(token_ids)
         # Compute the lengths of all user inputted tokens ids.
@@ -352,7 +346,7 @@ def next(prompt, cache, index):
                 cache,
             )
 
-        return self._sampler(
+        token_ids = self._sampler(
             next=next,
             prompt=token_ids,
             cache=cache,
@@ -362,6 +356,78 @@ def next(prompt, cache, index):
             hidden_states=hidden_states,
         )
 
+        # Compute an output padding mask with the token ids we updated.
+        if end_token_id is not None:
+            # Build a mask of `end_token_id` locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = (token_ids == end_token_id) & (~padding_mask)
+            end_locations = tf.cast(end_locations, tf.int32)
+            # Use cumsum to get ones in all locations after end_locations.
+            overflow = tf.math.cumsum(end_locations, exclusive=True)
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ~tf.cast(overflow, tf.bool)
+        else:
+            # Without early stopping, all locations will have been updated.
+            padding_mask = tf.ones_like(token_ids, dtype=tf.bool)
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    def _normalize_generate_inputs(
+        self,
+        inputs,
+    ):
+        """Normalize user input to the generate function.
+
+        This function converts all inputs to tensors, adds a batch dimension if
+        necessary, and returns an iterable "dataset like" object (either an
+        actual `tf.data.Dataset` or a list with a single batch element).
+        """
+        input_is_scalar = False
+
+        if isinstance(inputs, tf.data.Dataset):
+            return inputs, input_is_scalar
+
+        if isinstance(inputs, str) or isinstance(inputs, list):
+            inputs = tf.convert_to_tensor(inputs)
+
+        if isinstance(inputs, tf.Tensor) and inputs.shape.rank == 0:
+            input_is_scalar = True
+            inputs = inputs[tf.newaxis]
+
+        # We avoid converting to a dataset purely for speed, for a single batch
+        # of input, creating a dataset would add significant overhead.
+        return [inputs], input_is_scalar
+
+    def _normalize_generate_outputs(
+        self,
+        outputs,
+        input_is_scalar,
+    ):
+        """Normalize user output from the generate function.
+
+        This function converts all output to numpy (for integer output), or
+        python strings (for string output). If a batch dimension was added to
+        the input, it is removed from the output (so generate can be string in,
+        string out).
+        """
+
+        def normalize(x):
+            x = tf.concat(x, axis=0)
+            x = tf.squeeze(x, 0) if input_is_scalar else x
+            is_string = x.dtype == tf.string
+            # Convert outputs to a friendly pythonic type. For numerical outputs
+            # that is numpy, for string outputs that is `list` and `str`.
+            return tensor_to_string_list(x) if is_string else x.numpy()
+
+        if isinstance(outputs[0], dict):
+            return {
+                "token_ids": normalize([x["token_ids"] for x in outputs]),
+                "padding_mask": normalize([x["padding_mask"] for x in outputs]),
+            }
+        return normalize([x for x in outputs])
+
     def generate(
         self,
         inputs,
@@ -397,65 +463,43 @@ def generate(
             A string or string list if `preprocessor` is set, and a integer
             tensor of token IDs if `preprocessor is None`.
         """
-        input_is_scalar = False
-
+        # Setup our three main passes.
+        # 1. Optionally preprocessing strings to dense integer tensors.
+        # 2. Generate new tokens via a compiled function on dense tensors.
+        # 3. Optionally postprocess dense integer tensors back to string.
+        generate_function = self.make_generate_function()
+        end_token_id = None
         if self.preprocessor is not None:
+            end_token_id = self.preprocessor.tokenizer.end_token_id
 
-            def preprocess(x):
-                return self.preprocessor(
-                    x,
-                    sequence_length=max_length,
-                    return_labels=False,
-                    # We do not append an end token by default during
-                    # generation, as generating directly in the same sequence is
-                    # the most common workflow. If an end token directly after
-                    # a prompt is desired, it can be added to the prompt string.
-                    add_end_token=False,
-                )
-
-            if not isinstance(inputs, tf.data.Dataset):
-                inputs = tf.convert_to_tensor(inputs)
-                input_is_scalar = inputs.shape.rank == 0
-                inputs = inputs[tf.newaxis] if input_is_scalar else inputs
-                # Wrap a list to avoid the overhead of converting to dataset.
-                inputs = [preprocess(inputs)]
-            else:
+        def preprocess(x):
+            return self.preprocessor.generate_preprocess(
+                x, sequence_length=max_length
+            )
+
+        def generate(x):
+            return generate_function(x, end_token_id=end_token_id)
+
+        def postprocess(x):
+            return self.preprocessor.generate_postprocess(x)
+
+        # Normalize inputs, apply our three passes, and normalize outputs.
+        inputs, input_is_scalar = self._normalize_generate_inputs(inputs)
+
+        if self.preprocessor is not None:
+            if isinstance(inputs, tf.data.Dataset):
                 inputs = inputs.map(preprocess, tf.data.AUTOTUNE)
                 inputs = inputs.prefetch(tf.data.AUTOTUNE)
-        else:
-            if not isinstance(inputs, tf.data.Dataset):
-                # Wrap a list to avoid the overhead of converting to dataset.
-                inputs = [inputs]
+            else:
+                # Fast path for non-dataset, single-batch input.
+                inputs = [preprocess(x) for x in inputs]
 
-        generate_function = self.make_generate_function()
-        outputs = []
-        for batch in inputs:
-            token_ids, padding_mask = batch["token_ids"], batch["padding_mask"]
-            # If `preprocessor` is attached, we can stop after end_token_id.
-            end_token_id = None
-            if self.preprocessor is not None:
-                end_token_id = self.preprocessor.tokenizer.end_token_id
-            # Run the compiled generate function.
-            output = generate_function(token_ids, padding_mask, end_token_id)
-
-            if self.preprocessor is not None:
-                # Truncate to ragged by removing tokens after the first
-                # generated `end_token_id`.
-                output = truncate_at_token(output, end_token_id, padding_mask)
-                # Strip start token if added.
-                if self.preprocessor.add_start_token:
-                    output = output[:, 1:]
-                # Detokenize.
-                output = self.preprocessor.tokenizer.detokenize(output)
-            outputs.append(output)
-
-        outputs = tf.concat(outputs, axis=0)
-        outputs = tf.squeeze(outputs, 0) if input_is_scalar else outputs
-        # Convert outputs to a friendly pythonic type. For numerical outputs
-        # that is numpy, for string outputs that is `list` and `str`.
-        if outputs.dtype == tf.string:
-            return tensor_to_string_list(outputs)
-        return outputs.numpy()
+        outputs = [generate(x) for x in inputs]
+
+        if self.preprocessor is not None:
+            outputs = [postprocess(x) for x in outputs]
+
+        return self._normalize_generate_outputs(outputs, input_is_scalar)
 
     @classmethod
     def create_layout_map(cls, mesh):
```
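A quick aside on the `padding_mask` computation added to `generate_step` above: the exclusive cumsum marks every position strictly after the first `end_token_id` the model generated, so inverting it keeps the prompt, the generated tokens, and the end token itself. A small standalone sketch in plain TensorFlow (not from the commit; the `prompt_mask` name is illustrative):

```python
import tensorflow as tf

end_token_id = 6
token_ids = tf.constant([[1, 2, 3, 4, 5, 6, 9, 9]])
# The first four positions were the user prompt.
prompt_mask = tf.constant([[True, True, True, True, False, False, False, False]])

# End tokens the model generated (i.e. not part of the prompt).
end_locations = tf.cast((token_ids == end_token_id) & (~prompt_mask), tf.int32)
# Exclusive cumsum is zero up to and including the end token, positive after it.
overflow = tf.math.cumsum(end_locations, axis=-1, exclusive=True)
padding_mask = ~tf.cast(overflow, tf.bool)
print(padding_mask)  # [[True, True, True, True, True, True, False, False]]
```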

keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py

Lines changed: 73 additions & 33 deletions
````diff
@@ -14,23 +14,32 @@
 
 """GPT2 Causal LM preprocessor layer."""
 
+import tensorflow as tf
 from absl import logging
 
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
 from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
 
 
 @keras_nlp_export("keras_nlp.models.GPT2CausalLMPreprocessor")
 class GPT2CausalLMPreprocessor(GPT2Preprocessor):
     """GPT2 Causal LM preprocessor.
 
-    This preprocessing layer is primarily meant to be used with
+    This preprocessing layer is meant for use with
     `keras_nlp.models.GPT2CausalLM`. By default, it will take in batches of
     strings, and return outputs in a `(x, y, sample_weight)` format, where the
-    `y` label is the next token id in the `x` sequence. For use with generation,
-    pass `return_labels=False`, in which case the output will simply be the
-    encoded string features.
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+    is attached to a `keras_nlp.models.GPT2CausalLM` instance, these methods
+    will be called implicitly in `generate()`. They can also be called
+    standalone (e.g. to precompute preprocessing inputs for generation in a
+    separate process).
 
     Args:
         tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
@@ -47,12 +56,6 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor):
            generates label weights.
         sequence_length: Pass to override the configured `sequence_length` of
            the layer.
-        add_start_token: Pass to override the configured value of
-            `add_start_token` on the layer.
-        add_end_token: Pass to override the configured value of
-            `add_end_token` on the layer.
-        return_labels: If `True`, the output `"token_ids"` will be offset by one
-            and returned as labels. If `False` only features will be returned.
 
     Examples:
     ```python
@@ -95,9 +98,6 @@ def call(
         y=None,
         sample_weight=None,
         sequence_length=None,
-        add_start_token=None,
-        add_end_token=None,
-        return_labels=True,
     ):
         if y is not None or sample_weight is not None:
             logging.warning(
@@ -106,25 +106,65 @@ def call(
                 "or `sample_weight`. Your `y` and `sample_weight` will be "
                 "ignored."
             )
-        if return_labels:
-            # Tokenize with one extra token to account for the truncation below.
-            sequence_length = (sequence_length or self.sequence_length) + 1
-        x = super().call(
+        sequence_length = sequence_length or self.sequence_length
+
+        x = convert_inputs_to_list_of_tensor_segments(x)[0]
+        x = self.tokenizer(x)
+        # Pad with one extra token to account for the truncation below.
+        token_ids, padding_mask = self.packer(
             x,
-            sequence_length=sequence_length,
-            add_start_token=add_start_token,
-            add_end_token=add_end_token,
+            sequence_length=sequence_length + 1,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
         )
-        if return_labels:
-            token_ids, padding_mask = x["token_ids"], x["padding_mask"]
-            # The last token does not have a next token, so we truncate it out.
-            x = {
-                "token_ids": token_ids[..., :-1],
-                "padding_mask": padding_mask[..., :-1],
-            }
-            # Target `y` will be the next token.
-            y = token_ids[..., 1:]
-            sample_weight = padding_mask[..., 1:]
-            return pack_x_y_sample_weight(x, y, sample_weight)
-        else:
-            return x
+        # The last token does not have a next token, so we truncate it out.
+        x = {
+            "token_ids": token_ids[..., :-1],
+            "padding_mask": padding_mask[..., :-1],
+        }
+        # Target `y` will be the next token.
+        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+        return pack_x_y_sample_weight(x, y, sample_weight)
+
+    def generate_preprocess(
+        self,
+        x,
+        sequence_length=None,
+    ):
+        """Convert strings to integer token input for generation.
+
+        Similar to calling the layer for training, this method takes in strings
+        or tensor strings, tokenizes and packs the input, and computes a padding
+        mask masking all inputs not filled in with a padded value.
+
+        Unlike calling the layer for training, this method does not compute
+        labels and will never append a `tokenizer.end_token_id` to the end of
+        the sequence (as generation is expected to continue at the end of the
+        inputted prompt).
+        """
+        x = convert_inputs_to_list_of_tensor_segments(x)[0]
+        x = self.tokenizer(x)
+        token_ids, padding_mask = self.packer(
+            x, sequence_length=sequence_length, add_end_value=False
+        )
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    def generate_postprocess(
+        self,
+        x,
+    ):
+        """Convert integer token output to strings for generation.
+
+        This method reverses `generate_preprocess()`, by first removing all
+        padding and start/end tokens, and then converting the integer sequence
+        back to a string.
+        """
+        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+        # Strip any special tokens during detokenization (e.g. the start and
+        # end markers). In the future we could make this configurable.
+        padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
+        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
+        return self.tokenizer.detokenize(token_ids)
````
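As the new docstring notes, `generate_preprocess()` and `generate_postprocess()` can also be used outside of `generate()`, for example to precompute generation inputs in a `tf.data` pipeline. A hypothetical sketch, assuming the `gpt2_base_en` preset and that the preprocessor is constructed via `from_preset()`:

```python
import tensorflow as tf
import keras_nlp

# Assumption: standalone preprocessor via from_preset(); the preset name is illustrative.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset("gpt2_base_en")

prompts = tf.data.Dataset.from_tensor_slices(
    ["a quick brown fox", "the meaning of life is"]
)
# Precompute dense generation inputs in a tf.data pipeline.
prompts = prompts.batch(2).map(
    lambda x: preprocessor.generate_preprocess(x, sequence_length=64),
    num_parallel_calls=tf.data.AUTOTUNE,
)

for batch in prompts:
    # Each batch is a {"token_ids", "padding_mask"} dict ready for the compiled
    # generate function; postprocessing here simply reverses the tokenization.
    print(preprocessor.generate_postprocess(batch))
```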
