
Commit 65349af

Random Sampling Util for Text Generation (#228)
* reformatted greedy search with helper functions and added random sampling util
* reformat files
* split testing into two classes + minor changes
* formatted code
* naming changes
* format changes
* naming changes to random_search
* removed docstring for helper and added random_search to init
1 parent d32a47e commit 65349af
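
For context before the diffs: greedy decoding keeps the single most probable token at every step, while the new `random_search` draws the next token from the full distribution via `tf.random.categorical`, as the diff below shows. A minimal sketch of the two selection steps on a made-up distribution (the probabilities here are illustrative only):

```python
import tensorflow as tf

# Made-up next-token distribution for one sequence over a 4-token vocabulary.
pred = tf.constant([[0.01, 0.01, 0.08, 0.9]])

# Greedy selection: always the highest-probability token (deterministic).
greedy_token = tf.argmax(pred, axis=-1)  # -> [3]

# Random sampling, as random_search does below; tf.random.categorical
# expects unnormalized log-probabilities, hence the tf.math.log.
sampled_token = tf.random.categorical(tf.math.log(pred), num_samples=1)
# -> [[3]] roughly 90% of the time, but any token can be drawn.
```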

File tree

3 files changed: +250 −19 lines


keras_nlp/utils/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from keras_nlp.utils.text_generation import greedy_search
+from keras_nlp.utils.text_generation import random_search
```
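
With that one-line addition, the new utility is exposed from the same namespace as `greedy_search`; for example:

```python
from keras_nlp.utils import greedy_search, random_search
```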

keras_nlp/utils/text_generation.py

Lines changed: 136 additions & 16 deletions
````diff
@@ -17,6 +17,35 @@
 import tensorflow as tf
 
 
+def validate_prompt(prompt):
+    """
+    Helper function to validate input to text_generation utils.
+    """
+    if isinstance(prompt, tf.RaggedTensor):
+        raise ValueError(
+            "RaggedTensor `prompt` is not supported, please "
+            "provide `prompt` as a list or Tensor."
+        )
+    if not isinstance(prompt, tf.Tensor):
+        prompt = tf.convert_to_tensor(prompt)
+    return prompt
+
+
+def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
+    """
+    Helper function to mask the tokens after the end token.
+    """
+    # Mask out tokens after `end_token_id` is encountered.
+    # Find index of first end_token_id.
+    end_indices = tf.math.argmax(prompt == end_token_id, -1)
+    # Use max_length if no `end_token_id` is found.
+    end_indices = tf.where(end_indices == 0, max_length, end_indices)
+    # Build a mask including end_token and replace tokens after end_token
+    # with `pad_token_id`.
+    valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
+    return tf.where(valid_indices, prompt, pad_token_id)
+
+
 def greedy_search(
     token_probability_fn,
     prompt,
@@ -88,13 +117,9 @@ def token_probability_fn(inputs):
         "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
         "tf.function in eager mode."
     )
-    if isinstance(prompt, tf.RaggedTensor):
-        raise ValueError(
-            "RaggedTensor `prompt` is not supported, please "
-            "provide `prompt` as a list or Tensor."
-        )
-    if not isinstance(prompt, tf.Tensor):
-        prompt = tf.convert_to_tensor(prompt)
+
+    prompt = validate_prompt(prompt)
+
     input_is_1d = prompt.shape.rank == 1
     if input_is_1d:
         prompt = prompt[tf.newaxis, :]
@@ -109,16 +134,111 @@ def token_probability_fn(inputs):
         i += 1
 
     if end_token_id is not None:
-        # Mask out tokens after `end_token_id` is encountered.
-        # Find index of first end_token_id.
-        end_indices = tf.math.argmax(prompt == end_token_id, -1)
-        # Use max_length if no `end_token_id` is found.
-        end_indices = tf.where(end_indices == 0, max_length, end_indices)
-        # Build a mask including end_token and replace tokens after end_token
-        # with `pad_token_id`.
-        valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
-        prompt = tf.where(valid_indices, prompt, pad_token_id)
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
 
     if input_is_1d:
         return tf.squeeze(prompt)
     return prompt
+
+
+def random_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    seed=None,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on randomly sampling the entire probability
+    distribution.
+
+    Random sampling samples the next token from the probability distribution
+    provided by `token_probability_fn` and appends it to the existing sequence.
+
+    Args:
+        token_probability_fn: a callable, which takes in input_sequence
+            and outputs the probability distribution of the next token.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
+            append generated tokens to.
+        max_length: int. The max length of generated text.
+        seed: int, defaults to None. The random seed used for sampling.
+        end_token_id: int, defaults to None. The token marking the end of the
+            sequence; once encountered, generation is finished for that
+            sequence. If None, every sequence is generated up to `max_length`.
+            If set, all tokens after encountering `end_token_id` will be
+            replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The pad token inserted after
+            `end_token_id` is received.
+
+    Returns:
+        A 1D or 2D int Tensor representing the generated sequences.
+
+    Examples:
+    ```python
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.random_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        end_token_id=0,
+    )
+    ```
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.random_search` currently requires an eager "
+            "execution context. Please call `random_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+
+    prompt = validate_prompt(prompt)
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+
+    i = prompt.shape[1]
+    # Sample and append tokens until the prompt reaches `max_length`.
+    while i < max_length:
+        pred = token_probability_fn(prompt)
+        next_token = tf.cast(
+            tf.random.categorical(tf.math.log(pred), 1, seed=seed),
+            dtype=prompt.dtype,
+        )
+        # Append the next token to current sequence.
+        prompt = tf.concat([prompt, next_token], axis=-1)
+        i += 1
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
````
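
The extracted `mask_tokens_after_end_token` helper is the subtler part of the refactor, so here is a minimal, self-contained sketch of what it computes, with made-up token ids (these values are illustrative, not from the commit's tests):

```python
import tensorflow as tf

# end_token_id (0) appears at index 2 in row 0 and never in row 1.
prompt = tf.constant([[5, 2, 0, 7, 9],
                      [5, 2, 7, 9, 4]])
max_length = 5
end_token_id = 0
pad_token_id = 0

# argmax over the boolean mask returns the index of the first end token;
# it also returns 0 when there is no end token, which the next line maps
# to max_length ("keep everything").
end_indices = tf.math.argmax(prompt == end_token_id, -1)
end_indices = tf.where(end_indices == 0, max_length, end_indices)
# Keep tokens up to and including the end token, pad the rest.
valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
print(tf.where(valid_indices, prompt, pad_token_id))
# tf.Tensor([[5 2 0 0 0]
#            [5 2 7 9 4]], shape=(2, 5), dtype=int32)
```

One corollary of the argmax trick: an end token at index 0 is indistinguishable from "no end token found", so such a sequence is left unmasked.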

keras_nlp/utils/text_generation_test.py

Lines changed: 113 additions & 3 deletions
```diff
@@ -14,12 +14,14 @@
 """Tests for Text Generation Utils."""
 
 
+import numpy as np
 import tensorflow as tf
 
 from keras_nlp.utils.text_generation import greedy_search
+from keras_nlp.utils.text_generation import random_search
 
 
-class TextGenerationTest(tf.test.TestCase):
+class GreedySearchTextGenerationTest(tf.test.TestCase):
     def setUp(self):
         super().setUp()
         vocab_size = 10
@@ -66,7 +68,7 @@ def test_generate_with_ragged_prompt(self):
     def test_assert_generation_is_correct(self):
         def token_probability_fn(inputs):
             batch_size = inputs.shape[0]
-            prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
             return tf.repeat(prob, batch_size, axis=0)
 
         batch_size = 10
@@ -82,7 +84,7 @@ def token_probability_fn(inputs):
     def test_end_token_id(self):
         def token_probability_fn(inputs):
             batch_size = inputs.shape[0]
-            prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
             return tf.repeat(prob, batch_size, axis=0)
 
         max_length = 5
@@ -96,5 +98,113 @@ def token_probability_fn(inputs):
         )
         expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
         expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
+        self.assertAllEqual(outputs, expected_outputs)
+
+
+class RandomSamplingTextGenerationTest(tf.test.TestCase):
+    def setUp(self):
+        super().setUp()
+        vocab_size = 10
+        feature_size = 16
+
+        # Create a dummy model to predict the next token.
+        model = tf.keras.Sequential(
+            [
+                tf.keras.Input(shape=[None]),
+                tf.keras.layers.Embedding(
+                    input_dim=vocab_size,
+                    output_dim=feature_size,
+                ),
+                tf.keras.layers.Dense(vocab_size),
+                tf.keras.layers.Softmax(),
+            ]
+        )
+
+        def token_probability_fn(inputs):
+            return model(inputs)[:, -1, :]
+
+        self.token_probability_fn = token_probability_fn
+
+    def test_generate_with_1d_prompt(self):
+        inputs = tf.constant([1])
+        outputs = random_search(self.token_probability_fn, inputs, max_length=5)
+        self.assertEquals(outputs.shape, [5])
+
+    def test_generate_with_2d_prompt(self):
+        inputs = tf.constant([[1], [1]])
+        outputs = random_search(self.token_probability_fn, inputs, max_length=5)
+        self.assertEquals(outputs.shape, [2, 5])
+
+    def test_generate_with_list_prompt(self):
+        inputs = [[1], [1]]
+        outputs = random_search(self.token_probability_fn, inputs, max_length=5)
+        self.assertEquals(outputs.shape, [2, 5])
+
+    def test_generate_with_ragged_prompt(self):
+        inputs = tf.ragged.constant([[1], [2, 3]])
+        with self.assertRaises(ValueError):
+            random_search(self.token_probability_fn, inputs, max_length=5)
+
+    def test_assert_seeded_generation_is_correct(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+        tf.random.set_seed(42)
+        outputs = random_search(
+            token_probability_fn, inputs, max_length=max_length, seed=42
+        )
+        # Random sampling result with seed 42
+        seeded_result = 3 * np.ones(shape=[batch_size, max_length])
+        self.assertAllEqual(outputs, seeded_result)
+
+    def test_assert_probability_distribution_generation_is_correct(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+
+        outputs_count = np.array([0, 0, 0, 0])
+        tf.random.set_seed(42)
+        for i in range(500):
+            outputs = random_search(
+                token_probability_fn, inputs, max_length=max_length, seed=42
+            )
+            flatten_predictions = tf.reshape(outputs[:, 1:], [-1])
+            for pred in flatten_predictions:
+                outputs_count[pred] += 1
+        self.assertAllClose(
+            outputs_count / np.sum(outputs_count),
+            [0.01, 0.01, 0.08, 0.9],
+            rtol=0.2,
+        )
+
+    def test_end_token_id(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
 
+        max_length = 5
+        inputs = tf.constant([[0, 1], [1, 2]])
+
+        outputs = random_search(
+            token_probability_fn,
+            inputs,
+            max_length=max_length,
+            seed=42,
+            end_token_id=2,
+            pad_token_id=0,
+        )
+        # Random sampling result with seed 42
+        expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
+        expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
         self.assertAllEqual(outputs, expected_outputs)
```
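
One remark on the seeded tests above: they set both the global seed (`tf.random.set_seed(42)`) and the op-level `seed=42`, because TensorFlow's random ops only reproduce the same draw when the global seed is (re)set. A quick sketch of that behavior:

```python
import tensorflow as tf

# Resetting the global seed before each call makes the op-level-seeded
# draw reproducible; without the reset, consecutive calls would differ.
tf.random.set_seed(42)
first = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 5, seed=42)
tf.random.set_seed(42)
second = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 5, seed=42)
assert bool(tf.reduce_all(first == second))
```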
