Skip to content

Commit d817656

Browse files
Add FalconCausalLM (#1635)
* Add Falcon CausalLM. * Import FalconCausalLM in inits and uncomment a test. * Update generate_step function. * Pass padding_mask to call_with_cache. * Change int32 to int when casting. * Fix the keras2 test. * Cast attention_mask once. * Update docstrings. * Revert a change in an unrelated file! * Remove padding mask from call_with_cache. * Rename the mask variable in _build_alibi_tensor. * Remove endoftext from the end of the prompt in the example.
1 parent 0bf204d commit d817656

File tree

5 files changed

+491
-11
lines changed

5 files changed

+491
-11
lines changed

keras_nlp/api/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
from keras_nlp.src.models.f_net.f_net_preprocessor import FNetPreprocessor
101101
from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer
102102
from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
103+
from keras_nlp.src.models.falcon.falcon_causal_lm import FalconCausalLM
103104
from keras_nlp.src.models.falcon.falcon_causal_lm_preprocessor import (
104105
FalconCausalLMPreprocessor,
105106
)

keras_nlp/src/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
from keras_nlp.src.models.f_net.f_net_preprocessor import FNetPreprocessor
9696
from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer
9797
from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
98+
from keras_nlp.src.models.falcon.falcon_causal_lm import FalconCausalLM
9899
from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer
99100
from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone
100101
from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from keras_nlp.src.api_export import keras_nlp_export
16+
from keras_nlp.src.backend import ops
17+
from keras_nlp.src.models.causal_lm import CausalLM
18+
from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
19+
from keras_nlp.src.models.falcon.falcon_causal_lm_preprocessor import (
20+
FalconCausalLMPreprocessor,
21+
)
22+
from keras_nlp.src.utils.tensor_utils import any_equal
23+
24+
25+
@keras_nlp_export("keras_nlp.models.FalconCausalLM")
26+
class FalconCausalLM(CausalLM):
27+
"""An end-to-end Falcon model for causal language modeling.
28+
29+
A causal language model (LM) predicts the next token based on previous
30+
tokens. This task setup can be used to train the model unsupervised on
31+
plain text input, or to autoregressively generate plain text similar to
32+
the data used for training. This task can be used for pre-training or
33+
fine-tuning a Falcon model, simply by calling `fit()`.
34+
35+
This model has a `generate()` method, which generates text based on a
36+
prompt. The generation strategy used is controlled by an additional
37+
`sampler` argument on `compile()`. You can recompile the model with
38+
different `keras_nlp.samplers` objects to control the generation. By
39+
default, `"greedy"` sampling will be used.
40+
41+
This model can optionally be configured with a `preprocessor` layer, in
42+
which case it will automatically apply preprocessing to string inputs during
43+
`fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
44+
when creating the model with `from_preset()`.
45+
46+
Args:
47+
backbone: A `keras_nlp.models.FalconBackbone` instance.
48+
preprocessor: A `keras_nlp.models.FalconCausalLMPreprocessor` or `None`.
49+
If `None`, this model will not apply preprocessing, and inputs
50+
should be preprocessed before calling the model.
51+
52+
Examples:
53+
54+
Use `generate()` to do text generation.
55+
```python
56+
falcon_lm = keras_nlp.models.FalconCausalLM.from_preset("falcon_refinedweb_1b_en")
57+
falcon_lm.generate("I want to say", max_length=30)
58+
59+
# Generate with batched prompts.
60+
falcon_lm.generate(["This is a", "Where are you"], max_length=30)
61+
```
62+
63+
Compile the `generate()` function with a custom sampler.
64+
```python
65+
falcon_lm = keras_nlp.models.FalconCausalLM.from_preset("falcon_refinedweb_1b_en")
66+
falcon_lm.compile(sampler="top_k")
67+
falcon_lm.generate("I want to say", max_length=30)
68+
69+
falcon_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2))
70+
falcon_lm.generate("I want to say", max_length=30)
71+
```
72+
73+
Use `generate()` without preprocessing.
74+
```python
75+
prompt = {
76+
# Token ids for "<|endoftext|> Keras is".
77+
"token_ids": np.array([[50256, 17337, 292, 318]] * 2),
78+
# Use `"padding_mask"` to indicate values that should not be overridden.
79+
"padding_mask": np.array([[1, 1, 1, 1]] * 2),
80+
}
81+
82+
falcon_lm = keras_nlp.models.FalconCausalLM.from_preset(
83+
"falcon_refinedweb_1b_en",
84+
preprocessor=None,
85+
)
86+
falcon_lm.generate(prompt)
87+
```
88+
89+
Call `fit()` on a single batch.
90+
```python
91+
features = ["The quick brown fox jumped.", "I forgot my homework."]
92+
falcon_lm = keras_nlp.models.FalconCausalLM.from_preset("falcon_refinedweb_1b_en")
93+
falcon_lm.fit(x=features, batch_size=2)
94+
```
95+
96+
Call `fit()` without preprocessing.
97+
```python
98+
x = {
99+
# Token ids for "<|endoftext|> Keras is deep learning library<|endoftext|>"
100+
"token_ids": np.array([[50256, 17337, 292, 318, 2769, 4673, 5888, 50256, 0]] * 2),
101+
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 0]] * 2),
102+
}
103+
y = np.array([[17337, 292, 318, 2769, 4673, 5888, 50256, 0, 0]] * 2)
104+
sw = np.array([[1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2)
105+
106+
falcon_lm = keras_nlp.models.FalconCausalLM.from_preset(
107+
"falcon_refinedweb_1b_en",
108+
preprocessor=None,
109+
)
110+
falcon_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
111+
```
112+
113+
Custom backbone and vocabulary.
114+
```python
115+
vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6}
116+
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
117+
merges += ["Ġ f", "o x", "Ġf ox"]
118+
tokenizer = keras_nlp.models.FalconTokenizer(
119+
vocabulary=vocab,
120+
merges=merges,
121+
)
122+
preprocessor = keras_nlp.models.FalconCausalLMPreprocessor(
123+
tokenizer=tokenizer,
124+
sequence_length=128,
125+
)
126+
backbone = keras_nlp.models.FalconBackbone(
127+
vocabulary_size=50304,
128+
num_layers=24,
129+
num_attention_heads=64,
130+
hidden_dim=2048,
131+
intermediate_dim=4*2048,
132+
)
133+
falcon_lm = keras_nlp.models.FalconCausalLM(
134+
backbone=backbone,
135+
preprocessor=preprocessor,
136+
)
137+
falcon_lm.fit(x=features, batch_size=2)
138+
```
139+
"""
140+
141+
backbone_cls = FalconBackbone
142+
preprocessor_cls = FalconCausalLMPreprocessor
143+
144+
def __init__(
145+
self,
146+
backbone,
147+
preprocessor=None,
148+
**kwargs,
149+
):
150+
# === Layers ===
151+
self.backbone = backbone
152+
self.preprocessor = preprocessor
153+
154+
# === Functional Model ===
155+
inputs = backbone.input
156+
hidden_states = backbone(inputs)
157+
outputs = backbone.token_embedding(hidden_states, reverse=True)
158+
super().__init__(
159+
inputs=inputs,
160+
outputs=outputs,
161+
**kwargs,
162+
)
163+
164+
def call_with_cache(
165+
self,
166+
token_ids,
167+
cache,
168+
cache_update_index,
169+
):
170+
"""Forward pass of `FalconCausalLM` with cache.
171+
172+
`call_with_cache` adds an additional forward pass for the model for
173+
autoregressive inference. Unlike calling the model directly, this method
174+
allows caching previous key/value Tensors in multi-head attention layer,
175+
and avoids recomputing the outputs of seen tokens.
176+
177+
Args:
178+
token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
179+
cache: a dense float Tensor, the cache of key and value.
180+
cache_update_index: int, or int Tensor. The index of current inputs in the
181+
whole sequence.
182+
183+
Returns:
184+
A (logits, hidden_states, cache) tuple. Where `logits` is the
185+
language model logits for the input token_ids, `hidden_states` is
186+
the final hidden representation of the input tokens, and `cache` is
187+
the decoding cache.
188+
"""
189+
x = self.backbone.token_embedding(token_ids)
190+
# Each decoder layer has a cache; we update them separately.
191+
caches = []
192+
for i, transformer_layer in enumerate(self.backbone.transformer_layers):
193+
current_cache = cache[:, i, ...]
194+
x, next_cache = transformer_layer(
195+
x,
196+
attention_cache=current_cache,
197+
attention_cache_update_index=cache_update_index,
198+
)
199+
caches.append(next_cache)
200+
cache = ops.stack(caches, axis=1)
201+
hidden_states = x = self.backbone.final_layernorm(x)
202+
logits = self.backbone.token_embedding(x, reverse=True)
203+
return logits, hidden_states, cache
204+
205+
def _build_cache(self, token_ids):
206+
"""Build an empty cache for use with `call_with_cache()`."""
207+
batch_size = ops.shape(token_ids)[0]
208+
max_length = ops.shape(token_ids)[1]
209+
num_layers = self.backbone.num_layers
210+
num_heads = self.backbone.num_attention_heads
211+
head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads
212+
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
213+
cache = ops.zeros(shape, dtype=self.compute_dtype)
214+
# Seed the cache.
215+
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
216+
return hidden_states, cache
217+
218+
def generate_step(
219+
self,
220+
inputs,
221+
stop_token_ids=None,
222+
):
223+
"""A compilable generation function for a single batch of inputs.
224+
225+
This function represents the inner, XLA-compilable, generation function
226+
for a single batch of inputs. Inputs should have the same structure as
227+
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
228+
229+
Args:
230+
inputs: A dictionary with two keys `"token_ids"` and
231+
`"padding_mask"` and batched tensor values.
232+
stop_token_ids: Tuple of id's of end token's to stop on. If all
233+
sequences have produced a new stop token, generation
234+
will stop.
235+
"""
236+
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
237+
# Create and seed cache with a single forward pass.
238+
hidden_states, cache = self._build_cache(token_ids)
239+
# Compute the lengths of all user inputted tokens ids.
240+
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
241+
# Start at the first index that has no user inputted id.
242+
index = ops.min(row_lengths)
243+
244+
def next(prompt, cache, index):
245+
# The cache index is the index of our previous token.
246+
cache_update_index = index - 1
247+
batch_size = ops.shape(prompt)[0]
248+
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
249+
logits, hidden_states, cache = self.call_with_cache(
250+
prompt,
251+
cache,
252+
cache_update_index,
253+
)
254+
return (
255+
ops.squeeze(logits, axis=1),
256+
ops.squeeze(hidden_states, axis=1),
257+
cache,
258+
)
259+
260+
token_ids = self.sampler(
261+
next=next,
262+
prompt=token_ids,
263+
cache=cache,
264+
index=index,
265+
mask=padding_mask,
266+
stop_token_ids=stop_token_ids,
267+
hidden_states=hidden_states,
268+
model=self,
269+
)
270+
271+
# Compute an output padding mask with the token ids we updated.
272+
if stop_token_ids is not None:
273+
# Build a mask of stop token locations not in the original
274+
# prompt (not in locations where `padding_mask` is True).
275+
end_locations = any_equal(
276+
token_ids, stop_token_ids, ops.logical_not(padding_mask)
277+
)
278+
end_locations = ops.cast(end_locations, "int32")
279+
# Use cumsum to get ones in all locations after end_locations.
280+
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
281+
overflow = cumsum - end_locations
282+
# Our padding mask is the inverse of these overflow locations.
283+
padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
284+
else:
285+
# Without early stopping, all locations will have been updated.
286+
padding_mask = ops.ones_like(token_ids, dtype="bool")
287+
return {
288+
"token_ids": token_ids,
289+
"padding_mask": padding_mask,
290+
}

0 commit comments

Comments
 (0)