Commit 31674a1

added top k search util (#232)

* added top k search util
* reformat files
* added edits from comments and function validating
* Added optional parameter and testing
* edited docstring and changed format of from_logits

1 parent 65349af

3 files changed (+378, -16 lines)

keras_nlp/utils/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -14,3 +14,4 @@
 
 from keras_nlp.utils.text_generation import greedy_search
 from keras_nlp.utils.text_generation import random_search
+from keras_nlp.utils.text_generation import top_k_search
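
With this change, `top_k_search` is importable from `keras_nlp.utils` alongside `greedy_search` and `random_search`. A minimal usage sketch, assuming the API exactly as added in this commit; the uniform dummy `token_probability_fn` and the token ids below are illustrative, not part of the commit:

```python
import tensorflow as tf
import keras_nlp

VOCAB_SIZE = 10

# Illustrative dummy: a uniform next-token distribution over the vocabulary,
# shaped [batch_size, vocab_size] as the validation helper requires.
def token_probability_fn(inputs):
    batch_size = tf.shape(inputs)[0]
    return tf.ones([batch_size, VOCAB_SIZE]) / VOCAB_SIZE

prompt = tf.fill((2, 1), 1)  # hypothetical start-token id

generated = keras_nlp.utils.top_k_search(
    token_probability_fn,
    prompt,
    max_length=5,
    k=3,
)
print(generated)  # shape [2, 5]; each step samples from the top-3 tokens
```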

keras_nlp/utils/text_generation.py
Lines changed: 164 additions & 13 deletions

@@ -18,9 +18,7 @@
 
 
 def validate_prompt(prompt):
-    """
-    Helper function to validate input to text_generation utils.
-    """
+    """Helper function to validate input to text_generation utils."""
     if isinstance(prompt, tf.RaggedTensor):
         raise ValueError(
             "RaggedTensor `prompt` is not supported, please "
@@ -31,10 +29,19 @@ def validate_prompt(prompt):
     return prompt
 
 
+def validate_token_probability_fn(token_probability_fn, prompt):
+    """Helper function to validate token probability fn output"""
+    test_pred = token_probability_fn(prompt)
+    if len(test_pred.shape) != 2:
+        raise ValueError(
+            "Output of `token_probability_fn` is not a 2D tensor, "
+            "please provide a function with the output shape "
+            "[batch_size, vocab_size]."
+        )
+
+
 def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
-    """
-    Helper function to mask the tokens after the end token.
-    """
+    """Helper function to mask the tokens after the end token."""
     # Mask out tokens after `end_token_id` is encountered.
     # Find index of first end_token_id.
     end_indices = tf.math.argmax(prompt == end_token_id, -1)
@@ -61,7 +68,8 @@ def greedy_search(
 
     Args:
         token_probability_fn: a callable, which takes in input_sequence
-            and output the probability distribution of the next token.
+            and output the probability distribution or the logits of the next
+            token.
         prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
            append generated tokens.
         max_length: int. The max length of generated text.
@@ -79,8 +87,11 @@
 
     Examples:
     ```python
+    BATCH_SIZE = 8
     VOCAB_SIZE = 10
     FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
 
     # Create a dummy model to predict the next token.
     model = tf.keras.Sequential(
@@ -99,14 +110,15 @@ def token_probability_fn(inputs):
     def token_probability_fn(inputs):
         return model(inputs)[:, -1, :]
 
-    prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
 
     # Print the generated sequence (token ids).
     keras_nlp.utils.greedy_search(
         token_probability_fn,
         prompt,
         max_length=10,
-        end_token_id=0,)
+        end_token_id=END_ID,
+    )
     ```
 
     """
@@ -123,6 +135,7 @@ def token_probability_fn(inputs):
     input_is_1d = prompt.shape.rank == 1
     if input_is_1d:
         prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
 
     i = prompt.shape[1]
     while i < max_length:
@@ -148,6 +161,7 @@ def random_search(
     prompt,
     max_length,
     seed=None,
+    from_logits=False,
     end_token_id=None,
     pad_token_id=0,
 ):
@@ -160,11 +174,15 @@
 
     Args:
         token_probability_fn: a callable, which takes in input_sequence
-            and output the probability distribution of the next token.
+            and output the probability distribution of the next token. If
+            `from_logits` set to True, it should output the logits of the next
+            token.
         prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
             append generated tokens.
         max_length: int. The max length of generated text.
         seed: int, defaults to None. The random seed used for sampling.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
         end_token_id: int, defaults to None. The token marking the end of the
             sequence, once encountered the generation is finished for the exact
             sequence. If None, every sequence is generated up to `max_length`.
@@ -179,8 +197,11 @@
 
     Examples:
    ```python
+    BATCH_SIZE = 8
     VOCAB_SIZE = 10
     FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
 
     # Create a dummy model to predict the next token.
     model = tf.keras.Sequential(
@@ -199,14 +220,15 @@ def token_probability_fn(inputs):
     def token_probability_fn(inputs):
         return model(inputs)[:, -1, :]
 
-    prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
 
     # Print the generated sequence (token ids).
-    keras_nlp.utils.random_sampling(
+    keras_nlp.utils.random_search(
         token_probability_fn,
         prompt,
         max_length=10,
-        end_token_id=0,)
+        end_token_id=END_ID,
+    )
     ```
 
     """
@@ -222,11 +244,14 @@ def token_probability_fn(inputs):
     input_is_1d = prompt.shape.rank == 1
     if input_is_1d:
         prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
 
     i = prompt.shape[1]
     while i < max_length:
         # If the prompt has reached our desired length, exit while loop.
         pred = token_probability_fn(prompt)
+        if from_logits:
+            pred = tf.keras.activations.softmax(pred, axis=-1)
         next_token = tf.cast(
             tf.random.categorical(tf.math.log(pred), 1, seed=seed),
             dtype=prompt.dtype,
@@ -242,3 +267,129 @@ def token_probability_fn(inputs):
     if input_is_1d:
         return tf.squeeze(prompt)
     return prompt
+
+
+def top_k_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    k,
+    seed=None,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on top-k sampling.
+
+    Top-k search samples the next token from the top-k tokens in the
+    probability distribution provided by `token_probability_fn` and appends it
+    to the existing sequence.
+
+    Args:
+        token_probability_fn: a callable, which takes in input_sequence
+            and output the probability distribution of the next token. If
+            `from_logits` set to True, it should output the logits of the next
+            token.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
+            append generated tokens.
+        max_length: int. The max length of generated text.
+        k: int. The number of top tokens to sample from. Should be non-negative
+            and less than the vocabulary size.
+        seed: int, defaults to None. The random seed used for sampling.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token marking the end of the
+            sequence, once encountered the generation is finished for the exact
+            sequence. If None, every sequence is generated up to `max_length`.
+            If set, all tokens after encountering `end_token_id` will be
+            replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The pad token after `end_token_id`
+            is received.
+
+    Returns:
+        A 1D int Tensor, or 2D int Tensor representing the generated
+        sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.top_k_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        k=4,
+        end_token_id=END_ID,
+    )
+    ```
+
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.top_k_search` currently requires an eager "
+            "execution context. Please call `top_k_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+    if k <= 0:
+        raise ValueError("k should be strictly positive (greater than 0).")
+
+    prompt = validate_prompt(prompt)
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    i = prompt.shape[1]
+    while i < max_length:
+        # If the prompt has reached our desired length, exit while loop.
+        pred = token_probability_fn(prompt)
+        if from_logits:
+            pred = tf.keras.activations.softmax(pred, axis=-1)
+        # If k is greater than the vocabulary size, use the entire vocabulary.
+        k = min(k, pred.shape[1])
+        # Filter out top-k tokens.
+        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k)
+        # Sample the next token from the probability distribution.
+        next_token = tf.random.categorical(
+            tf.math.log(top_k_pred), 1, seed=seed
+        )
+        # Rearrange to get the next token idx from the original order.
+        next_token = tf.gather_nd(top_k_indices, next_token, batch_dims=1)
+        next_token = tf.cast(next_token, dtype=prompt.dtype)
+        # Append the next token to current sequence.
+        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
+        i += 1
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
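
One subtlety in the `top_k_search` loop above is the index bookkeeping: sampling happens inside the filtered k-sized distribution, so the sampled position must be translated back to a vocabulary id. A small standalone sketch of just that step, using a made-up batch of distributions (not from the commit):

```python
import tensorflow as tf

# Made-up next-token distributions for a batch of 2 over a 5-token vocabulary.
pred = tf.constant([[0.10, 0.40, 0.10, 0.30, 0.10],
                    [0.05, 0.05, 0.60, 0.20, 0.10]])

# Keep the k highest probabilities; `top_k_indices` records where each kept
# probability sat in the full vocabulary.
top_k_pred, top_k_indices = tf.math.top_k(pred, k=2)
# top_k_indices is [[1, 3], [2, 3]] for this `pred`.

# Sample a position within each row's k-sized distribution. Taking the log of
# the unrenormalized probabilities is fine: tf.random.categorical treats its
# input as unnormalized log-probabilities.
sampled = tf.random.categorical(tf.math.log(top_k_pred), 1)  # shape [2, 1]

# Map each sampled position back to a vocabulary id, batch by batch.
next_token = tf.gather_nd(top_k_indices, sampled, batch_dims=1)  # shape [2]
print(next_token)  # e.g. [1, 2] or [3, 3], depending on the draw
```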
