 from keras_nlp.utils.keras_utils import is_xla_compatible
 from keras_nlp.utils.python_utils import classproperty
 from keras_nlp.utils.tf_utils import tensor_to_string_list
-from keras_nlp.utils.tf_utils import truncate_at_token
 
 
 @keras_nlp_export("keras_nlp.models.GPT2CausalLM")
@@ -49,7 +48,7 @@ class GPT2CausalLM(Task):
     default, `"top_k"` sampling will be used.
 
     This model can optionally be configured with a `preprocessor` layer, in
-    which case it will automatically apply preprocessing to raw inputs during
+    which case it will automatically apply preprocessing to string inputs during
     `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
     when creating the model with `from_preset()`.
 
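A minimal usage sketch of the behavior described in the docstring above, assuming the "gpt2_base_en" preset is available (the preset name and prompt are illustrative, and weights download on first use):

import keras_nlp

# `from_preset()` attaches a matching preprocessor by default, so raw strings
# can be passed directly to `generate()`, `fit()`, `predict()` and `evaluate()`.
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
causal_lm.generate("The quick brown fox", max_length=30)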
@@ -306,28 +305,23 @@ def make_generate_function(self):
 
     def generate_step(
         self,
-        token_ids,
-        padding_mask,
+        inputs,
         end_token_id=None,
     ):
         """A compilable generation function for a single batch of inputs.
 
         This function represents the inner, XLA-compilable, generation function
-        for a single batch of inputs. It takes in a dense `tf.Tensor` of token
-        ids, and return a dense `tf.Tensor` of token ids, and includes no
-        preprocessing. This function is wrapped by the `generate()` method.
+        for a single batch of inputs. Inputs should have the same structure as
+        model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
 
         Args:
-            token_ids: A dense int Tensor, with shape
-                `(batch_size, max_length)`. The user provided token ids
-                padded to `max_length`.
-            padding_mask: A dense boolean Tensor, with the same shape as
-                `token_ids`. Positions that are True in the `padding_mask`
-                are assumed to be user input and never updated.
+            inputs: A dictionary with two keys `"token_ids"` and
+                `"padding_mask"` and batched tensor values.
             end_token_id: The id of the end token to stop on. If all
                 sequences have produced a new `end_token_id`, generation
                 will stop.
         """
+        token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
         # Create and seed cache with a single forward pass.
         hidden_states, cache = self._build_cache(token_ids)
         # Compute the lengths of all user inputted token ids.
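For reference, a small sketch of the input structure the reworked generate_step() now expects; the token id values below are illustrative assumptions, not taken from this change:

import tensorflow as tf

# One prompt of four real tokens padded out to length eight. Positions that are
# True in "padding_mask" are user input and are never overwritten by sampling.
inputs = {
    "token_ids": tf.constant([[10, 11, 12, 13, 0, 0, 0, 0]], dtype=tf.int32),
    "padding_mask": tf.constant(
        [[True, True, True, True, False, False, False, False]]
    ),
}
# generate_step(inputs, end_token_id=...) returns a dict with the same two
# keys, with newly sampled ids filled into the False positions.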
@@ -352,7 +346,7 @@ def next(prompt, cache, index):
                 cache,
             )
 
-        return self._sampler(
+        token_ids = self._sampler(
             next=next,
             prompt=token_ids,
             cache=cache,
@@ -362,6 +356,78 @@ def next(prompt, cache, index):
             hidden_states=hidden_states,
         )
 
+        # Compute an output padding mask with the token ids we updated.
+        if end_token_id is not None:
+            # Build a mask of `end_token_id` locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = (token_ids == end_token_id) & (~padding_mask)
+            end_locations = tf.cast(end_locations, tf.int32)
+            # Use cumsum to get ones in all locations after end_locations.
+            overflow = tf.math.cumsum(end_locations, axis=-1, exclusive=True)
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ~tf.cast(overflow, tf.bool)
+        else:
+            # Without early stopping, all locations will have been updated.
+            padding_mask = tf.ones_like(token_ids, dtype=tf.bool)
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    def _normalize_generate_inputs(
+        self,
+        inputs,
+    ):
+        """Normalize user input to the generate function.
+
+        This function converts all inputs to tensors, adds a batch dimension if
+        necessary, and returns an iterable "dataset like" object (either an
+        actual `tf.data.Dataset` or a list with a single batch element).
+        """
+        input_is_scalar = False
+
+        if isinstance(inputs, tf.data.Dataset):
+            return inputs, input_is_scalar
+
+        if isinstance(inputs, str) or isinstance(inputs, list):
+            inputs = tf.convert_to_tensor(inputs)
+
+        if isinstance(inputs, tf.Tensor) and inputs.shape.rank == 0:
+            input_is_scalar = True
+            inputs = inputs[tf.newaxis]
+
+        # We avoid converting to a dataset purely for speed; for a single batch
+        # of input, creating a dataset would add significant overhead.
+        return [inputs], input_is_scalar
+
+    def _normalize_generate_outputs(
+        self,
+        outputs,
+        input_is_scalar,
+    ):
+        """Normalize user output from the generate function.
+
+        This function converts all output to numpy (for integer output), or
+        python strings (for string output). If a batch dimension was added to
+        the input, it is removed from the output (so generate can be string in,
+        string out).
+        """
+
+        def normalize(x):
+            x = tf.concat(x, axis=0)
+            x = tf.squeeze(x, 0) if input_is_scalar else x
+            is_string = x.dtype == tf.string
+            # Convert outputs to a friendly pythonic type. For numerical outputs
+            # that is numpy, for string outputs that is `list` and `str`.
+            return tensor_to_string_list(x) if is_string else x.numpy()
+
+        if isinstance(outputs[0], dict):
+            return {
+                "token_ids": normalize([x["token_ids"] for x in outputs]),
+                "padding_mask": normalize([x["padding_mask"] for x in outputs]),
+            }
+        return normalize([x for x in outputs])
+
     def generate(
         self,
         inputs,
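To illustrate the end-token handling added to generate_step() above, a small standalone walkthrough of the cumsum trick with toy values (the tensors are made up for illustration):

import tensorflow as tf

end_token_id = 2
# A finished sequence: the first two positions were the user prompt, position 4
# holds a generated end token, and positions 5-6 were sampled after it.
token_ids = tf.constant([[5, 6, 7, 8, 2, 9, 9]])
padding_mask = tf.constant([[True, True, False, False, False, False, False]])

# Mark end tokens that were generated, i.e. not part of the original prompt.
end_locations = tf.cast((token_ids == end_token_id) & (~padding_mask), tf.int32)
# An exclusive cumsum is zero up to and including the first end token, and
# positive everywhere after it.
overflow = tf.math.cumsum(end_locations, axis=-1, exclusive=True)
# Inverting keeps the prompt, the generated tokens, and the first end token,
# and masks out everything sampled after it.
new_padding_mask = ~tf.cast(overflow, tf.bool)
# new_padding_mask -> [[True, True, True, True, True, False, False]]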
@@ -397,65 +463,43 @@ def generate(
             A string or string list if `preprocessor` is set, and an integer
             tensor of token IDs if `preprocessor is None`.
         """
-        input_is_scalar = False
-
+        # Set up our three main passes.
+        # 1. Optionally preprocess strings to dense integer tensors.
+        # 2. Generate new tokens via a compiled function on dense tensors.
+        # 3. Optionally postprocess dense integer tensors back to strings.
+        generate_function = self.make_generate_function()
+        end_token_id = None
         if self.preprocessor is not None:
+            end_token_id = self.preprocessor.tokenizer.end_token_id
 
-            def preprocess(x):
-                return self.preprocessor(
-                    x,
-                    sequence_length=max_length,
-                    return_labels=False,
-                    # We do not append an end token by default during
-                    # generation, as generating directly in the same sequence is
-                    # the most common workflow. If an end token directly after
-                    # a prompt is desired, it can be added to the prompt string.
-                    add_end_token=False,
-                )
-
-            if not isinstance(inputs, tf.data.Dataset):
-                inputs = tf.convert_to_tensor(inputs)
-                input_is_scalar = inputs.shape.rank == 0
-                inputs = inputs[tf.newaxis] if input_is_scalar else inputs
-                # Wrap a list to avoid the overhead of converting to dataset.
-                inputs = [preprocess(inputs)]
-            else:
+        def preprocess(x):
+            return self.preprocessor.generate_preprocess(
+                x, sequence_length=max_length
+            )
+
+        def generate(x):
+            return generate_function(x, end_token_id=end_token_id)
+
+        def postprocess(x):
+            return self.preprocessor.generate_postprocess(x)
+
+        # Normalize inputs, apply our three passes, and normalize outputs.
+        inputs, input_is_scalar = self._normalize_generate_inputs(inputs)
+
+        if self.preprocessor is not None:
+            if isinstance(inputs, tf.data.Dataset):
                 inputs = inputs.map(preprocess, tf.data.AUTOTUNE)
                 inputs = inputs.prefetch(tf.data.AUTOTUNE)
-        else:
-            if not isinstance(inputs, tf.data.Dataset):
-                # Wrap a list to avoid the overhead of converting to dataset.
-                inputs = [inputs]
+            else:
+                # Fast path for non-dataset, single-batch input.
+                inputs = [preprocess(x) for x in inputs]
 
-        generate_function = self.make_generate_function()
-        outputs = []
-        for batch in inputs:
-            token_ids, padding_mask = batch["token_ids"], batch["padding_mask"]
-            # If `preprocessor` is attached, we can stop after end_token_id.
-            end_token_id = None
-            if self.preprocessor is not None:
-                end_token_id = self.preprocessor.tokenizer.end_token_id
-            # Run the compiled generate function.
-            output = generate_function(token_ids, padding_mask, end_token_id)
-
-            if self.preprocessor is not None:
-                # Truncate to ragged by removing tokens after the first
-                # generated `end_token_id`.
-                output = truncate_at_token(output, end_token_id, padding_mask)
-                # Strip start token if added.
-                if self.preprocessor.add_start_token:
-                    output = output[:, 1:]
-                # Detokenize.
-                output = self.preprocessor.tokenizer.detokenize(output)
-            outputs.append(output)
-
-        outputs = tf.concat(outputs, axis=0)
-        outputs = tf.squeeze(outputs, 0) if input_is_scalar else outputs
-        # Convert outputs to a friendly pythonic type. For numerical outputs
-        # that is numpy, for string outputs that is `list` and `str`.
-        if outputs.dtype == tf.string:
-            return tensor_to_string_list(outputs)
-        return outputs.numpy()
+        outputs = [generate(x) for x in inputs]
+
+        if self.preprocessor is not None:
+            outputs = [postprocess(x) for x in outputs]
+
+        return self._normalize_generate_outputs(outputs, input_is_scalar)
 
     @classmethod
     def create_layout_map(cls, mesh):
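Putting the refactor together, a sketch of how generate() now behaves for the different input types it normalizes (the preset name and prompts are illustrative assumptions):

import tensorflow as tf
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Scalar string in, scalar string out (a batch dimension is added and removed).
causal_lm.generate("The quick brown fox", max_length=30)

# A list of prompts is converted to a tensor and treated as one batch.
causal_lm.generate(["That winter", "One morning"], max_length=30)

# A tf.data.Dataset is mapped through preprocessing and generated batch by batch.
ds = tf.data.Dataset.from_tensor_slices(["That winter", "One morning"]).batch(2)
causal_lm.generate(ds, max_length=30)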