 
 
 def validate_prompt(prompt):
-    """
-    Helper function to validate input to text_generation utils.
-    """
+    """Helper function to validate input to text_generation utils."""
     if isinstance(prompt, tf.RaggedTensor):
         raise ValueError(
             "RaggedTensor `prompt` is not supported, please "
@@ -31,10 +29,19 @@ def validate_prompt(prompt):
     return prompt
 
 
+def validate_token_probability_fn(token_probability_fn, prompt):
+    """Helper function to validate the output of `token_probability_fn`."""
+    test_pred = token_probability_fn(prompt)
+    if len(test_pred.shape) != 2:
+        raise ValueError(
+            "Output of `token_probability_fn` is not a 2D tensor, "
+            "please provide a function with the output shape "
+            "[batch_size, vocab_size]."
+        )
+
+
 def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
-    """
-    Helper function to mask the tokens after the end token.
-    """
+    """Helper function to mask the tokens after the end token."""
     # Mask out tokens after `end_token_id` is encountered.
     # Find index of first end_token_id.
     end_indices = tf.math.argmax(prompt == end_token_id, -1)
@@ -61,7 +68,8 @@ def greedy_search(
 
     Args:
         token_probability_fn: a callable, which takes in input_sequence
-            and output the probability distribution of the next token.
+            and outputs the probability distribution or the logits of the
+            next token.
         prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
             append generated tokens.
         max_length: int. The max length of generated text.
@@ -79,8 +87,11 @@ def greedy_search(
 
     Examples:
     ```python
+    BATCH_SIZE = 8
     VOCAB_SIZE = 10
     FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
 
     # Create a dummy model to predict the next token.
     model = tf.keras.Sequential(
@@ -99,14 +110,15 @@ def greedy_search(
     def token_probability_fn(inputs):
         return model(inputs)[:, -1, :]
 
-    prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
 
     # Print the generated sequence (token ids).
     keras_nlp.utils.greedy_search(
         token_probability_fn,
         prompt,
         max_length=10,
-        end_token_id=0,)
+        end_token_id=END_ID,
+    )
     ```
 
     """
@@ -123,6 +135,7 @@ def token_probability_fn(inputs):
     input_is_1d = prompt.shape.rank == 1
     if input_is_1d:
         prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
 
     i = prompt.shape[1]
     while i < max_length:
@@ -148,6 +161,7 @@ def random_search(
     prompt,
     max_length,
     seed=None,
+    from_logits=False,
     end_token_id=None,
     pad_token_id=0,
 ):
@@ -160,11 +174,15 @@ def random_search(
 
     Args:
         token_probability_fn: a callable, which takes in input_sequence
-            and output the probability distribution of the next token.
+            and outputs the probability distribution of the next token. If
+            `from_logits` is set to True, it should output the logits of the
+            next token.
         prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
             append generated tokens.
         max_length: int. The max length of generated text.
         seed: int, defaults to None. The random seed used for sampling.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
         end_token_id: int, defaults to None. The token marking the end of the
             sequence, once encountered the generation is finished for the exact
             sequence. If None, every sequence is generated up to `max_length`.
@@ -179,8 +197,11 @@ def random_search(
 
     Examples:
     ```python
+    BATCH_SIZE = 8
     VOCAB_SIZE = 10
     FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
 
     # Create a dummy model to predict the next token.
     model = tf.keras.Sequential(
@@ -199,14 +220,15 @@ def random_search(
     def token_probability_fn(inputs):
         return model(inputs)[:, -1, :]
 
-    prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
 
     # Print the generated sequence (token ids).
-    keras_nlp.utils.random_sampling(
+    keras_nlp.utils.random_search(
         token_probability_fn,
         prompt,
         max_length=10,
-        end_token_id=0,)
+        end_token_id=END_ID,
+    )
     ```
 
     """
@@ -222,11 +244,14 @@ def token_probability_fn(inputs):
     input_is_1d = prompt.shape.rank == 1
     if input_is_1d:
         prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
 
     i = prompt.shape[1]
     while i < max_length:
         # If the prompt has reached our desired length, exit while loop.
         pred = token_probability_fn(prompt)
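+        # `tf.math.log(pred)` below expects probabilities, so normalize
+        # logits with a softmax before sampling.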
+        if from_logits:
+            pred = tf.keras.activations.softmax(pred, axis=-1)
         next_token = tf.cast(
             tf.random.categorical(tf.math.log(pred), 1, seed=seed),
             dtype=prompt.dtype,
@@ -242,3 +267,129 @@ def token_probability_fn(inputs):
     if input_is_1d:
         return tf.squeeze(prompt)
     return prompt
+
+
+def top_k_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    k,
+    seed=None,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on top-k sampling.
+
+    Top-k search samples the next token from the top-k tokens in the
+    probability distribution provided by `token_probability_fn` and appends
+    it to the existing sequence.
+
+    Args:
+        token_probability_fn: a callable, which takes in input_sequence
+            and outputs the probability distribution of the next token. If
+            `from_logits` is set to True, it should output the logits of the
+            next token.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
+            append generated tokens.
+        max_length: int. The max length of generated text.
+        k: int. The number of top tokens to sample from. Must be strictly
+            positive; values larger than the vocabulary size are clamped to
+            the vocabulary size.
+        seed: int, defaults to None. The random seed used for sampling.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token marking the end of the
+            sequence, once encountered the generation is finished for the exact
+            sequence. If None, every sequence is generated up to `max_length`.
+            If set, all tokens after encountering `end_token_id` will be
+            replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The pad token after `end_token_id`
+            is received.
+
+    Returns:
+        A 1D int Tensor, or 2D int Tensor representing the generated
+        sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.top_k_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        k=4,
+        end_token_id=END_ID,
+    )
+    ```
+
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.top_k_search` currently requires an eager "
+            "execution context. Please call `top_k_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+    if k <= 0:
+        raise ValueError("k should be strictly positive (greater than 0).")
+
+    prompt = validate_prompt(prompt)
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    i = prompt.shape[1]
+    while i < max_length:
+        # Compute the next-token distribution for the current sequences.
+        pred = token_probability_fn(prompt)
+        if from_logits:
+            pred = tf.keras.activations.softmax(pred, axis=-1)
+        # If k is greater than the vocabulary size, use the entire vocabulary.
+        k = min(k, pred.shape[1])
+        # Select the top-k tokens and their probabilities.
+        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k)
+        # Sample the next token from the probability distribution.
+        next_token = tf.random.categorical(
+            tf.math.log(top_k_pred), 1, seed=seed
+        )
+        # Map the sampled index within the top-k subset back to its
+        # vocabulary id in the original order.
+        next_token = tf.gather_nd(top_k_indices, next_token, batch_dims=1)
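+        # e.g. with k=4: if a row of `top_k_indices` is [7, 3, 9, 0] and the
+        # sampled index for that row is 2, the selected vocabulary id is 9.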
+        next_token = tf.cast(next_token, dtype=prompt.dtype)
+        # Append the next token to the current sequence.
+        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
+        i += 1
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
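
For quick reference, here is a minimal sketch of the new `from_logits=True` path, mirroring the docstring examples above (the dummy model is illustrative only; its `Dense` layer has no softmax, so it emits logits):

```python
import tensorflow as tf
import keras_nlp

BATCH_SIZE = 8
VOCAB_SIZE = 10
FEATURE_SIZE = 16
START_ID = 1
END_ID = 2

# Dummy next-token model that outputs raw logits (no softmax activation).
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=[None]),
        tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=FEATURE_SIZE),
        tf.keras.layers.Dense(VOCAB_SIZE),
    ]
)

def token_probability_fn(inputs):
    # Logits for the next token, shape [batch_size, vocab_size].
    return model(inputs)[:, -1, :]

prompt = tf.fill((BATCH_SIZE, 1), START_ID)

# `from_logits=True` makes the utility apply a softmax before sampling.
output = keras_nlp.utils.top_k_search(
    token_probability_fn,
    prompt,
    max_length=10,
    k=4,
    from_logits=True,
    end_token_id=END_ID,
)
print(output)
```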