@@ -26,6 +26,10 @@ def validate_prompt(prompt):
         )
     if not isinstance(prompt, tf.Tensor):
         prompt = tf.convert_to_tensor(prompt)
+    if prompt.shape[-1] == 0:
+        raise ValueError(
+            "Length of `prompt` is 0, please provide a non-empty `prompt`."
+        )
     return prompt
 
 
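A quick sketch of the failure mode this new guard catches; `tf` is TensorFlow, and the empty prompt is purely illustrative:

```python
import tensorflow as tf

# An empty prompt has a trailing dimension of 0, so there is nothing to
# condition generation on; the new check raises immediately instead of
# failing later inside the sampling loop.
prompt = tf.constant([], dtype=tf.int32)
print(prompt.shape[-1])  # 0 -> validate_prompt now raises a ValueError
```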
@@ -357,7 +361,7 @@ def token_probability_fn(inputs):
             "tf.function in eager mode."
         )
     if k <= 0:
-        raise ValueError("k should be strictly positive (greater than 0).")
+        raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")
 
     prompt = validate_prompt(prompt)
     input_is_1d = prompt.shape.rank == 1
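For illustration, a minimal sketch of the sharpened message, assuming this hunk sits inside the exported `keras_nlp.utils.top_k_search` (the enclosing function is not named in the diff):

```python
import tensorflow as tf
import keras_nlp

# Dummy uniform next-token distribution over a 10-token vocabulary.
def token_probability_fn(inputs):
    return tf.ones((tf.shape(inputs)[0], 10)) / 10.0

try:
    # k=0 is invalid; validation now reports the received value.
    keras_nlp.utils.top_k_search(token_probability_fn, [1], max_length=5, k=0)
except ValueError as e:
    print(e)  # `k` should be strictly positive. Received: `k=0`.
```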
@@ -393,3 +397,147 @@ def token_probability_fn(inputs):
     if input_is_1d:
         return tf.squeeze(prompt)
     return prompt
+
+
+def top_p_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    p,
+    seed=None,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on top-p (nucleus) sampling.
+
+    Top-p search selects tokens from the smallest subset of output
+    probabilities that sum to greater than `p`. Put another way, top-p
+    first orders token predictions by likelihood, then ignores all tokens
+    once the cumulative probability of the selected tokens exceeds `p`.
+    The probability of each token is provided by `token_probability_fn`.
+    Args:
+        token_probability_fn: a callable, which takes in an input sequence
+            and outputs the probability distribution of the next token. If
+            `from_logits` is set to True, it should output the logits of
+            the next token.
+        prompt: a list or a Tensor, can be 1D or 2D, holding the initial
+            tokens to which generated tokens are appended.
+        max_length: int. The max length of generated text.
+        p: float. The probability that the top tokens sum up to. Must
+            satisfy 0 < p < 1.
+        seed: int, defaults to None. The random seed used for sampling.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token marking the end of
+            the sequence; once encountered, generation finishes for that
+            sequence. If None, every sequence is generated up to
+            `max_length`. If set, all tokens after encountering
+            `end_token_id` will be replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The pad token used after
+            `end_token_id` is encountered.
+
+    Returns:
+        A 1D int Tensor, or 2D int Tensor representing the generated
+        sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Generate the sequence (token ids).
+    keras_nlp.utils.top_p_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        p=0.8,
+        end_token_id=END_ID,
+    )
+    ```
+
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.top_p_search` currently requires an eager "
+            "execution context. Please call `top_p_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+    if p <= 0 or p >= 1:
+        raise ValueError(
+            f"`p` should be in the range (0, 1). Received: `p={p}`."
+        )
+
+    prompt = validate_prompt(prompt)
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    i = prompt.shape[1]
+    # Iterate until the prompt reaches the desired length.
+    while i < max_length:
+        pred = token_probability_fn(prompt)
+        if from_logits:
+            pred = tf.keras.activations.softmax(pred, axis=-1)
+        # Sort preds in descending order.
+        sorted_preds, sorted_indices = tf.math.top_k(
+            pred, k=pred.shape[1], sorted=True
+        )
+        # Calculate the cumulative probability distribution.
+        cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
+        # Create a mask for the tokens to keep.
+        keep_mask = cumulative_probs <= p
+        # Shift the mask right so the first token that pushes the cumulative
+        # probability past `p` is also kept.
+        shifted_keep_mask = tf.concat(
+            [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
+        )
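+        # For example, with sorted probs [0.5, 0.3, 0.1, 0.1] and p=0.7:
+        # cumulative_probs is [0.5, 0.8, 0.9, 1.0], keep_mask is
+        # [True, False, False, False], and shifted_keep_mask is
+        # [True, True, False, False], so the nucleus keeps the top two tokens.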
+        # Zero out tokens that fall outside the mask, then sample from the
+        # filtered distribution.
+        probs = tf.where(
+            shifted_keep_mask,
+            sorted_preds,
+            tf.zeros(pred.shape, dtype=sorted_preds.dtype),
+        )
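+        # Note: `tf.random.categorical` expects logits; `tf.math.log` turns
+        # the zeroed entries into -inf, so masked tokens are never sampled.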
+        sorted_next_token = tf.random.categorical(
+            tf.math.log(probs), 1, seed=seed
+        )
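+        # `sorted_next_token` indexes into the sorted order; gather maps it
+        # back to the original vocabulary ids.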
+        next_token = tf.gather_nd(
+            sorted_indices, sorted_next_token, batch_dims=1
+        )
+        next_token = tf.cast(next_token, dtype=prompt.dtype)
+        # Append the next token to the current sequence.
+        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
+        i += 1
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
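To see the nucleus filtering in isolation, here is a minimal standalone sketch that mirrors the masking logic above on a fixed toy distribution (values chosen for illustration):

```python
import tensorflow as tf

# Toy next-token distribution: batch of 1, vocabulary of 5, already sorted.
pred = tf.constant([[0.5, 0.3, 0.1, 0.06, 0.04]])
p = 0.7

sorted_preds, sorted_indices = tf.math.top_k(pred, k=pred.shape[1], sorted=True)
cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)  # [0.5, 0.8, 0.9, 0.96, 1.0]
keep_mask = cumulative_probs <= p                         # [T, F, F, F, F]
# Shift right so the first token past `p` stays in the nucleus.
shifted_keep_mask = tf.concat(
    [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
)                                                         # [T, T, F, F, F]
# Only token ids 0 and 1 (probs 0.5 and 0.3) can now be sampled.
probs = tf.where(shifted_keep_mask, sorted_preds, tf.zeros_like(sorted_preds))
sampled = tf.random.categorical(tf.math.log(probs), 1, seed=42)
next_token = tf.gather_nd(sorted_indices, sampled, batch_dims=1)
print(next_token)  # either token id 0 or 1
```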