
Commit fa6b74d

ai-edge-bot authored and copybara-github committed
Re-author PaliGemma2
- model.forward() and model.generate() produce the same results as the original model.
- PaliGemma2 uses different masks in forward() depending on whether it is called from generate() or not.
- Use model_config.embedding_scale in Gemma2 instead of a hard-coded scale, as Gemma1 does.
- Define image_projection_scale separately from the text embedding_scale because they differ in PaliGemma2.
- Simplify ReauthoredModelWrapper._forward_with_kv_cache() with keyword arguments.
- tflite generation and unit tests will come in a following change.

PiperOrigin-RevId: 708628768
1 parent eafcb12 commit fa6b74d

File tree

10 files changed (+398, -73 lines)

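The commit message notes that PaliGemma2's forward() picks different attention masks depending on whether it was called from generate(). Below is a minimal sketch of that selection logic, not the code in this commit: the shapes, the kv_cache_max value, and the plain causal mask are illustrative stand-ins for the per-layer masks the real model builds.

# Minimal sketch of the per-layer mask selection; not the library's exact code.
# embeds_len, kv_cache_max, num_layers, and the plain causal mask are assumptions.
from typing import List

import torch


def select_masks(
    called_by_generate: bool,
    embeds_len: int,
    kv_cache_max: int,
    num_layers: int,
) -> List[torch.Tensor]:
  if called_by_generate:
    # generate(): every layer keeps a diagonal causal mask (the real model also
    # distinguishes global vs. sliding-window attention per layer).
    causal = torch.triu(
        torch.full((embeds_len, kv_cache_max), float("-inf")), diagonal=1
    )
    return [causal] * num_layers
  # forward() on mixed image+text embeddings: all embedding positions may attend
  # to each other; only the unused tail of the KV cache is masked out.
  mask = torch.zeros(embeds_len, kv_cache_max)
  mask[:, embeds_len:] = float("-inf")
  return [mask] * num_layers


prefill_masks = select_masks(
    called_by_generate=False, embeds_len=260, kv_cache_max=1024, num_layers=26
)
print(prefill_masks[0].shape)  # torch.Size([260, 1024])

With called_by_generate=True the masks stay diagonal/causal, matching how the decode loop consumes one token at a time; with False, image and text embeddings can attend to each other during prefill.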

ai_edge_torch/generative/examples/gemma/gemma1.py

Lines changed: 3 additions & 2 deletions
@@ -72,12 +72,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=2048,
-      embedding_scale=2048**0.5,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 32 additions & 13 deletions
@@ -15,7 +15,7 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -136,29 +136,45 @@ def forward(
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
 
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
     # RoPE parameters are the same for all blocks. Use the first layer.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(
         input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
+    mask = [self.get_attention_mask(
+        self.config.block_config(i).attn_config.attn_type, input_pos
+    ) for i in range(self.config.num_layers)]
 
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
     updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
@@ -227,11 +243,13 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
     )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -248,6 +266,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
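
The Gemma2 change above moves the embedding scale out of forward() and into the model config, guarded by a None check in _forward_with_embeds(). A small self-contained sketch of that pattern, using a hypothetical dataclass in place of cfg.ModelConfig:

# Sketch of the config-driven embedding scale; ScaleDemoConfig is a stand-in
# for cfg.ModelConfig, not the library's actual class.
import dataclasses
from typing import Optional

import torch


@dataclasses.dataclass
class ScaleDemoConfig:
  embedding_dim: int = 2304
  embedding_scale: Optional[float] = None  # None means "do not scale"


config = ScaleDemoConfig()
config.embedding_scale = config.embedding_dim**0.5  # set once, next to embedding_dim

input_embeds = torch.randn(1, 4, config.embedding_dim)
if config.embedding_scale is not None:  # same guard as in _forward_with_embeds()
  input_embeds = input_embeds * config.embedding_scale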

ai_edge_torch/generative/examples/paligemma/decoder.py

Lines changed: 5 additions & 3 deletions
@@ -55,6 +55,7 @@ def forward(
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
       return super().forward(tokens, input_pos, kv_cache)
@@ -75,7 +76,7 @@
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
     mask[:, embeds_len:] = float("-inf")
 
-    return self.forward_with_embeds(
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache
     )
 
@@ -113,12 +114,13 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=257216,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=2048,
-      embedding_scale=2048**0.5,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@ (new file, 172 added lines shown below without diff markers)
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of building a decoder of PaliGemma2 3B model which is Gemma2."""

from typing import Optional

from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
    pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
    post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
    embedding="language_model.model.embed_tokens",
    final_norm="language_model.model.norm",
    lm_head=None,
)


class Decoder2(gemma2.Gemma2):
  """A decoder of PaliGemma2 3B model which is Gemma2.

  Besides a tensor of text token IDs, forward() can also take a tensor of
  embeddings which may include text or image or both.
  """

  @torch.inference_mode
  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
      input_embeds: torch.Tensor = None,
      export_config: Optional[model_builder.ExportConfig] = None,
      called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    if input_embeds is None:
      return super().forward(tokens, input_pos, kv_cache)

    assert input_embeds is not None

    repo_pos = input_pos + 1  # PaliGemma2 position is 1-based.
    # ROPE parameters for all attn_configs are the same. Take the first one.
    attn_config = self.config.block_config(0).attn_config
    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
    rope = rotary_pos_emb.build_rope(
        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
    )

    if called_by_generate:
      # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
      mask = [self.get_attention_mask(
          self.config.block_config(i).attn_config.attn_type, input_pos
      ) for i in range(self.config.num_layers)]
    else:
      # By default, don't mask image embeds with a diagonal causal mask.
      embeds_len = input_embeds.shape[1]
      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
      mask[:, embeds_len:] = float("-inf")
      mask = [mask] * self.config.num_layers

    return self._forward_with_embeds(
        input_embeds, rope, mask, input_pos, kv_cache, export_config
    )


def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for the decoder of a PaliGemma 3B model.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
      is 1024.

  Returns:
    The model config for the decoder of a PaliGemma 3B model.
  """
  norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-6,
      zero_centered=True,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
      intermediate_size=9216,
      pre_ff_norm_config=norm_config,
      post_ff_norm_config=norm_config,
  )

  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
    attn_config = cfg.AttentionConfig(
        num_heads=8,
        head_dim=256,
        num_query_groups=4,
        rotary_base=10000,
        rotary_percentage=1.0,
        logit_softcap=50.0,
        sliding_window_size=4096,
        attn_type=(
            cfg.AttentionType.GLOBAL
            if idx % 2 == 0
            else cfg.AttentionType.LOCAL_SLIDING
        ),
    )
    return cfg.TransformerBlockConfig(
        attn_config=attn_config,
        ff_config=ff_config,
        pre_attention_norm_config=norm_config,
        post_attention_norm_config=norm_config,
    )

  num_layers = 26
  embedding_dim = 2304
  config = cfg.ModelConfig(
      vocab_size=257216,
      num_layers=num_layers,
      max_seq_len=8192,
      embedding_dim=embedding_dim,
      embedding_scale=embedding_dim**0.5,
      kv_cache_max_len=kv_cache_max_len,
      block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
      lm_head_use_bias=False,
      enable_hlfb=True,
      final_logit_softcap=30.0,
  )
  return config


def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  config = get_decoder2_config(kv_cache_max_len)
  # PaliGemma2 decoder has only one block config.
  config.block_config(0).ff_config.intermediate_size = 128
  config.vocab_size = 128
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
  return config


def build_decoder2(checkpoint_path: str, **kwargs) -> torch.nn.Module:
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=get_decoder2_config(**kwargs),
      tensor_names=TENSOR_NAMES,
      model_class=Decoder2,
  )
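
A hedged usage sketch of the new Decoder2, not part of the commit. The module path (the file name is not shown in this diff), the checkpoint path, the tensor sizes, and the KVCache.from_model_config helper are all assumptions for illustration.

# Usage sketch only; paths, shapes, and the cache constructor are assumed.
import torch

from ai_edge_torch.generative.examples.paligemma import decoder2  # assumed module path
from ai_edge_torch.generative.layers import kv_cache as kv_utils

model = decoder2.build_decoder2("/path/to/paligemma2_checkpoint")  # placeholder path
config = model.config  # the example models keep their ModelConfig here
kv = kv_utils.KVCache.from_model_config(config)  # assumed cache constructor

# Prefill on mixed image+text embeddings. called_by_generate=False selects the
# non-diagonal mask so every embedding position can attend to the others.
embeds = torch.randn(1, 260, config.embedding_dim)
input_pos = torch.arange(260)
output = model(
    tokens=None,
    input_pos=input_pos,
    kv_cache=kv,
    input_embeds=embeds,
    called_by_generate=False,
)
# output carries the logits and the updated KV cache for the decode loop.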
