
Commit 5c87ada

jessechancy (Jesse Chan) and mattdangerw authored
top p search and testing (#233)
* top p search and testing
* made filter_value a default 0
* style fixes
* minor changes
* minor changes and addition of empty prompt checks
* Fix typo

Co-authored-by: Jesse Chan <[email protected]>
Co-authored-by: Matt Watson <[email protected]>
1 parent 31674a1 commit 5c87ada

File tree

2 files changed: +358 -5 lines changed


keras_nlp/utils/text_generation.py

Lines changed: 149 additions & 1 deletion
@@ -26,6 +26,10 @@ def validate_prompt(prompt):
         )
     if not isinstance(prompt, tf.Tensor):
         prompt = tf.convert_to_tensor(prompt)
+    if prompt.shape[-1] == 0:
+        raise ValueError(
+            "Length of `prompt` is 0, please provide a non-empty `prompt`."
+        )
     return prompt
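An aside on the new guard: `tf.convert_to_tensor` gives an empty prompt a trailing dimension of size 0, which is exactly what `prompt.shape[-1] == 0` catches for both 1D and 2D inputs. A quick sketch, not part of the commit, using made-up toy prompts:

```python
import tensorflow as tf

# Toy prompts; the first two carry no seed tokens and would now be rejected.
for value in ([], [[], []], [1, 2], [[1], [2]]):
    t = tf.convert_to_tensor(value, dtype=tf.int32)
    print(t.shape, "rejected" if t.shape[-1] == 0 else "accepted")
# (0,) rejected
# (2, 0) rejected
# (2,) accepted
# (2, 1) accepted
```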

@@ -357,7 +361,7 @@ def token_probability_fn(inputs):
             "tf.function in eager mode."
         )
     if k <= 0:
-        raise ValueError("k should be strictly positive (greater than 0).")
+        raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")
 
     prompt = validate_prompt(prompt)
     input_is_1d = prompt.shape.rank == 1
@@ -393,3 +397,147 @@ def token_probability_fn(inputs):
     if input_is_1d:
         return tf.squeeze(prompt)
     return prompt
+
+
+def top_p_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    p,
+    seed=None,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on top-p (nucleus) sampling.
+
+    Top-p search selects tokens from the smallest subset of output
+    probabilities that sum to greater than `p`. Put another way, top-p
+    first orders token predictions by likelihood, then ignores all tokens
+    after the cumulative probability of the selected tokens exceeds `p`.
+    The probability of each token is provided by `token_probability_fn`.
+
+    Args:
+        token_probability_fn: a callable, which takes in an input sequence
+            and outputs the probability distribution of the next token. If
+            `from_logits` is set to True, it should output the logits of
+            the next token.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
+            append generated tokens to.
+        max_length: int. The max length of generated text.
+        p: float. The probability that the top tokens sum up to. Should
+            satisfy the constraint 0 < p < 1.
+        seed: int, defaults to None. The random seed used for sampling.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token marking the end of
+            the sequence; once encountered, generation finishes for that
+            sequence. If None, every sequence is generated up to
+            `max_length`. If set, all tokens after encountering
+            `end_token_id` will be replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The pad token used after
+            `end_token_id` is received.
+
+    Returns:
+        A 1D int Tensor or 2D int Tensor representing the generated
+        sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.top_p_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        p=0.8,
+        end_token_id=END_ID,
+    )
+    ```
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.top_p_search` currently requires an eager "
+            "execution context. Please call `top_p_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to "
+            "run tf.function in eager mode."
+        )
+    if p <= 0 or p >= 1:
+        raise ValueError(
+            f"`p` should be in the range (0, 1). Received: `p={p}`."
+        )
+
+    prompt = validate_prompt(prompt)
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    # Generate tokens until the prompt reaches the desired length.
+    i = prompt.shape[1]
+    while i < max_length:
+        pred = token_probability_fn(prompt)
+        if from_logits:
+            pred = tf.keras.activations.softmax(pred, axis=-1)
+        # Sort preds in descending order.
+        sorted_preds, sorted_indices = tf.math.top_k(
+            pred, k=pred.shape[1], sorted=True
+        )
+        # Calculate the cumulative probability distribution.
+        cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
+        # Create a mask for the tokens to keep.
+        keep_mask = cumulative_probs <= p
+        # Shift to include the last token that exceeds `p`.
+        shifted_keep_mask = tf.concat(
+            [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
+        )
+        # Zero out tokens outside the mask and sample from what remains.
+        probs = tf.where(
+            shifted_keep_mask,
+            sorted_preds,
+            tf.zeros(pred.shape, dtype=sorted_preds.dtype),
+        )
+        sorted_next_token = tf.random.categorical(
+            tf.math.log(probs), 1, seed=seed
+        )
+        next_token = tf.gather_nd(
+            sorted_indices, sorted_next_token, batch_dims=1
+        )
+        next_token = tf.cast(next_token, dtype=prompt.dtype)
+        # Append the next token to the current sequence.
+        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
+        i += 1
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
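To make the nucleus-selection arithmetic above concrete, here is a worked sketch of the sort/cumsum/mask/shift steps on a hypothetical five-token distribution with `p=0.8` (the values are illustrative, not from the commit):

```python
import tensorflow as tf

p = 0.8
# Hypothetical next-token distribution for one sequence.
pred = tf.constant([[0.10, 0.40, 0.05, 0.30, 0.15]])

sorted_preds, sorted_indices = tf.math.top_k(
    pred, k=pred.shape[1], sorted=True
)
# sorted_preds   -> [[0.40, 0.30, 0.15, 0.10, 0.05]]
# sorted_indices -> [[1, 3, 4, 0, 2]]

cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
# -> [[0.40, 0.70, 0.85, 0.95, 1.00]]

keep_mask = cumulative_probs <= p
# -> [[True, True, False, False, False]]

# Shifting right by one also keeps the token whose inclusion first pushes
# the cumulative sum past `p` (0.15 here), so the nucleus is tokens 1, 3
# and 4.
shifted_keep_mask = tf.concat(
    [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
)
# -> [[True, True, True, False, False]]
```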

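The sampling step then works on log-probabilities: tokens zeroed out by the mask become `-inf` under `tf.math.log`, so `tf.random.categorical` assigns them zero sampling probability, and `tf.gather_nd(..., batch_dims=1)` maps the sampled position in the sorted order back to a vocabulary id. A sketch continuing the toy values above (again an illustration, not part of the commit):

```python
import tensorflow as tf

# Nucleus from the example above: only the first three sorted tokens remain.
probs = tf.constant([[0.40, 0.30, 0.15, 0.0, 0.0]])
sorted_indices = tf.constant([[1, 3, 4, 0, 2]])

# log(0) = -inf, so the masked tokens can never be sampled.
sorted_next_token = tf.random.categorical(tf.math.log(probs), 1, seed=42)

# Translate the position within the sorted order back to a vocabulary id.
next_token = tf.gather_nd(sorted_indices, sorted_next_token, batch_dims=1)
print(next_token)  # one of [1], [3] or [4]
```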