diff --git a/ai_edge_torch/generative/examples/embedding_gemma/convert_to_tflite.py b/ai_edge_torch/generative/examples/embedding_gemma/convert_to_tflite.py new file mode 100644 index 00000000..3c0a4010 --- /dev/null +++ b/ai_edge_torch/generative/examples/embedding_gemma/convert_to_tflite.py @@ -0,0 +1,59 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Example of converting EmbeddingGemma-300M model to TFLite.""" + +import os + +from absl import app +import ai_edge_torch as at +from ai_edge_torch.generative.examples.embedding_gemma import embedding_gemma +from ai_edge_torch.generative.utilities import converter as generative_converter +import torch + +flags = generative_converter.define_conversion_flags( + model_name="embedding_gemma" +) +FLAGS = flags.FLAGS + + +def main(_): + model = embedding_gemma.build_model(FLAGS.checkpoint_path) + model.eval() + seq_len = max(FLAGS.prefill_seq_lens) + + sample_inputs = ( + torch.ones(1, seq_len, dtype=torch.long), # tokens + torch.ones(1, seq_len, dtype=torch.long), # attention_mask + ) + + quant_config = generative_converter.get_quant_recipe_from_flag( + FLAGS.quantize, model.config + ) + edge_model = at.convert( + model, + sample_inputs, + quant_config=quant_config, + ) + + output_dir = FLAGS.output_path + quant_suffix = generative_converter.create_quantize_suffix(FLAGS.quantize) + output_filename = f"{FLAGS.output_name_prefix}_{quant_suffix}.tflite" + output_path = os.path.join(output_dir, output_filename) + edge_model.export(output_path) + print(f"TFLite model successfully saved to {output_path}") + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/examples/embedding_gemma/embedding_gemma.py b/ai_edge_torch/generative/examples/embedding_gemma/embedding_gemma.py new file mode 100644 index 00000000..a8a72c44 --- /dev/null +++ b/ai_edge_torch/generative/examples/embedding_gemma/embedding_gemma.py @@ -0,0 +1,214 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""EmbeddingGemma-300M model implementation.""" + +import math +import os +from typing import Callable, Dict + +from ai_edge_torch.generative.layers import attention +from ai_edge_torch.generative.layers import attention_utils +from ai_edge_torch.generative.layers import normalization as norm +import ai_edge_torch.generative.layers.model_config as cfg +from ai_edge_torch.generative.utilities import loader +from safetensors.torch import load_file +import torch +from torch import nn + + +class EmbeddingGemma(nn.Module): + """EmbeddingGemma-300M model implementation.""" + + def __init__(self, config: cfg.ModelConfig): + super().__init__() + self.config = config + self.embedder = nn.Embedding( + config.vocab_size, config.embedding_dim, padding_idx=0 + ) + self.transformer_blocks = nn.ModuleList([ + attention.TransformerBlock(block_config, config) + for block_config in config.block_configs + ]) + self.dense1 = nn.Linear(config.embedding_dim, 3072, bias=False) + self.dense2 = nn.Linear(3072, config.embedding_dim, bias=False) + + def _prepare_attention_mask(self, attention_mask, input_shape, dtype, device): + """Creates a padding attention mask.""" + batch_size, seq_len = input_shape + if attention_mask is None: + return torch.zeros((batch_size, 1, 1, seq_len), dtype=dtype, device=device) + padding_mask = torch.where( + attention_mask == 0, torch.finfo(dtype).min, 0.0 + ) + return padding_mask[:, None, None, :] + + def mean_pool(self, last_hidden_states, attention_mask): + """Mean pooling of hidden states, ignoring padding tokens.""" + masked_hidden_states = last_hidden_states * attention_mask.unsqueeze(-1) + sum_hidden_states = masked_hidden_states.sum(dim=1) + count = attention_mask.sum(dim=1).unsqueeze(-1) + return sum_hidden_states / (count + 1e-9) + + def forward( + self, tokens: torch.Tensor, attention_mask: torch.Tensor | None = None + ) -> torch.Tensor: + batch_size, seq_len = tokens.shape + if attention_mask is None: + attention_mask = torch.ones(batch_size, seq_len, device=tokens.device) + + x = self.embedder(tokens) + x = x * math.sqrt(self.config.embedding_dim) + + positions = torch.arange(0, seq_len, device=tokens.device) + attn_mask = self._prepare_attention_mask( + attention_mask, (batch_size, seq_len), x.dtype, x.device + ) + rope_cos, rope_sin = attention_utils.build_rope_cache( + size=self.config.max_seq_len, + dim=self.config.block_configs[0].attn_config.head_dim, + base=self.config.block_configs[0].attn_config.rotary_base, + dtype=x.dtype, + device=x.device, + ) + rope = (rope_cos[positions], rope_sin[positions]) + + for block in self.transformer_blocks: + x = block(x, rope, attn_mask, kv_cache=None) + + pooled_x = self.mean_pool(x, attention_mask) + pooled_x = self.dense1(pooled_x) + pooled_x = self.dense2(pooled_x) + normalized_x = torch.nn.functional.normalize(pooled_x, p=2, dim=1) + return normalized_x + + +def get_model_config() -> cfg.ModelConfig: + """Returns the model config for EmbeddingGemma-300M.""" + attn_config = cfg.AttentionConfig( + num_heads=3, + head_dim=256, + num_query_groups=1, # MQA + rotary_base=1000000, + rotary_percentage=1.0, + ) + ff_config = cfg.FeedForwardConfig( + type=cfg.FeedForwardType.GATED, + activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH), + intermediate_size=1152, + ) + norm_config = cfg.NormalizationConfig( + type=cfg.NormalizationType.RMS_NORM, + epsilon=1e-6, + zero_centered=True, + ) + block_config = cfg.TransformerBlockConfig( + 
attn_config=attn_config, + ff_config=ff_config, + pre_attention_norm_config=norm_config, + post_attention_norm_config=norm_config, + parallel_residual=False, + ) + config = cfg.ModelConfig( + vocab_size=262144, + num_layers=24, + max_seq_len=8192, + embedding_dim=768, + block_configs=[block_config] * 24, + final_norm_config=norm_config, + ) + return config + + +def build_model( + checkpoint_path: str, + custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None, +) -> nn.Module: + """Builds the EmbeddingGemma-300M model.""" + config = get_model_config() + model = EmbeddingGemma(config) + state_dict = {} + + has_sub_dirs = os.path.exists( + os.path.join(checkpoint_path, "2_Dense", "model.safetensors") + ) or os.path.exists( + os.path.join(checkpoint_path, "2_Dense", "pytorch_model.bin") + ) + + if has_sub_dirs: + try: + weights = loader.load_safetensors(checkpoint_path) + weights_dense1 = load_file( + os.path.join(checkpoint_path, "2_Dense", "model.safetensors") + ) + weights_dense2 = load_file( + os.path.join(checkpoint_path, "3_Dense", "model.safetensors") + ) + state_dict["dense1.weight"] = weights_dense1["linear.weight"] + state_dict["dense2.weight"] = weights_dense2["linear.weight"] + except Exception: + weights = loader.load_pytorch_statedict(checkpoint_path) + weights_dense1 = torch.load( + os.path.join(checkpoint_path, "2_Dense", "pytorch_model.bin") + ) + weights_dense2 = torch.load( + os.path.join(checkpoint_path, "3_Dense", "pytorch_model.bin") + ) + state_dict["dense1.weight"] = weights_dense1["linear.weight"] + state_dict["dense2.weight"] = weights_dense2["linear.weight"] + else: + try: + weights = loader.load_safetensors(checkpoint_path) + state_dict["dense1.weight"] = weights["dense1.weight"] + state_dict["dense2.weight"] = weights["dense2.weight"] + except Exception: + weights = loader.load_pytorch_statedict(checkpoint_path) + state_dict["dense1.weight"] = weights["dense1.weight"] + state_dict["dense2.weight"] = weights["dense2.weight"] + + state_dict["embedder.weight"] = weights["embed_tokens.weight"] + + for i in range(config.num_layers): + layer_prefix = f"layers.{i}" + tb_prefix = f"transformer_blocks.{i}" + # Norms + state_dict[f"{tb_prefix}.pre_atten_norm.weight"] = weights[ + f"{layer_prefix}.input_layernorm.weight" + ] + state_dict[f"{tb_prefix}.post_atten_norm.weight"] = weights[ + f"{layer_prefix}.post_attention_layernorm.weight" + ] + # Attention + q = weights[f"{layer_prefix}.self_attn.q_proj.weight"] + k = weights[f"{layer_prefix}.self_attn.k_proj.weight"] + v = weights[f"{layer_prefix}.self_attn.v_proj.weight"] + state_dict[f"{tb_prefix}.atten_func.qkv_projection.weight"] = torch.cat( + [q, k, v], dim=0 + ) + state_dict[f"{tb_prefix}.atten_func.output_projection.weight"] = weights[ + f"{layer_prefix}.self_attn.o_proj.weight" + ] + # Feed-forward + state_dict[f"{tb_prefix}.ff.w1.weight"] = weights[ + f"{layer_prefix}.mlp.gate_proj.weight" + ] + state_dict[f"{tb_prefix}.ff.w3.weight"] = weights[ + f"{layer_prefix}.mlp.up_proj.weight" + ] + state_dict[f"{tb_prefix}.ff.w2.weight"] = weights[ + f"{layer_prefix}.mlp.down_proj.weight" + ] + + model.load_state_dict(state_dict) + return model diff --git a/ai_edge_torch/generative/examples/embedding_gemma/verify.py b/ai_edge_torch/generative/examples/embedding_gemma/verify.py new file mode 100644 index 00000000..387f978a --- /dev/null +++ b/ai_edge_torch/generative/examples/embedding_gemma/verify.py @@ -0,0 +1,39 @@ +# Copyright 2025 The AI Edge Torch Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Verifies the reauthored EmbeddingGemma-300M model.""" + +from absl import app +from absl import flags + +from ai_edge_torch.generative.examples.embedding_gemma import verify_util + +_PROMPTS = flags.DEFINE_multi_string( + "prompts", + None, + "The input prompts to generate embeddings for.", +) +CHECKPOINT = "google/embeddinggemma-300m" + + +def main(_): + if not verify_util.verify_embedding_gemma_300m( + checkpoint_dir=CHECKPOINT, + prompts=_PROMPTS.value, + ): + exit(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/examples/embedding_gemma/verify_util.py b/ai_edge_torch/generative/examples/embedding_gemma/verify_util.py new file mode 100644 index 00000000..959d7df5 --- /dev/null +++ b/ai_edge_torch/generative/examples/embedding_gemma/verify_util.py @@ -0,0 +1,100 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Verification utilities for EmbeddingGemma-300M.""" + +from ai_edge_torch.generative.examples.embedding_gemma import embedding_gemma +from huggingface_hub import snapshot_download # pylint: disable=g-importing-member +from sentence_transformers import SentenceTransformer +import torch +import transformers + + +DEFAULT_PROMPTS = [ + "What is the meaning of life?", + "This is an example sentence.", +] + + +def _mean_pool(last_hidden_states, attention_mask): + """Mean pooling of hidden states, ignoring padding tokens.""" + masked_hidden_states = last_hidden_states * attention_mask.unsqueeze(-1) + sum_hidden_states = masked_hidden_states.sum(dim=1) + count = attention_mask.sum(dim=1).unsqueeze(-1) + count = torch.clamp(count, min=1e-9) + return sum_hidden_states / count + + +def verify_embedding_gemma_300m( + checkpoint_dir: str, + prompts: list[str] | None = None, + atol: float = 0.25, +) -> bool: + """Verifies EmbeddingGemma-300M.""" + try: + print(f"Downloading model from: {checkpoint_dir}") + model_path = snapshot_download(repo_id=checkpoint_dir) + print(f"Model downloaded to: {model_path}") + except Exception as e: + print(f"Error downloading model {checkpoint_dir}: {e}") + return False + + print(f"Loading tokenizer from: {model_path}") + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + + print(f"Loading original model from: {model_path}") + try: + original_model = SentenceTransformer(model_path) + original_model.eval() + except Exception as e: + print(f"Failed to load original model: {e}") + return False + + print(f"Building reauthored model from: {model_path}") + try: + reauthored_model = embedding_gemma.build_model(model_path) + reauthored_model.eval() + except (OSError, ValueError) as e: + print(f"Failed to build or load reauthored model: {e}") + return False + + prompts_to_run = prompts if prompts is not None else DEFAULT_PROMPTS + print(f"Tokenizing prompts: {prompts_to_run}") + inputs = tokenizer( + prompts_to_run, return_tensors="pt", padding=True, truncation=True + ) + tokens, attention_mask = inputs["input_ids"], inputs["attention_mask"] + + print("Running inference...") + with torch.no_grad(): + # SentenceTransformer model directly returns embeddings + original_embeddings = original_model.encode( + prompts_to_run, convert_to_tensor=True + ) + print(f"Original embeddings shape: {original_embeddings.shape}") + # Reauthored model includes pooling and norm in forward pass + reauthored_embeddings = reauthored_model( + tokens, attention_mask=attention_mask + ) + + if not torch.allclose(original_embeddings, reauthored_embeddings, atol=atol): + print("Verification failed: Embeddings do not match!") + print( + "Max difference:" + f" {torch.max(torch.abs(original_embeddings - reauthored_embeddings))}" + ) + return False + + print("Verification successful: Embeddings match.") + return True diff --git a/ai_edge_torch/generative/examples/falcon/convert_to_tflite.py b/ai_edge_torch/generative/examples/falcon/convert_to_tflite.py new file mode 100644 index 00000000..65b8881d --- /dev/null +++ b/ai_edge_torch/generative/examples/falcon/convert_to_tflite.py @@ -0,0 +1,51 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Example of converting Falcon-1B model to multi-signature tflite model.""" + +from absl import app +from ai_edge_torch.generative.examples.falcon import falcon +from ai_edge_torch.generative.utilities import converter +from ai_edge_torch.generative.utilities import export_config as export_cfg +from ai_edge_torch.generative.utilities import loader + +flags = converter.define_conversion_flags('falcon-1b') + +def main(_): + checkpoint_path = flags.FLAGS.checkpoint_path + pytorch_model = falcon.build_model( + checkpoint_path, + custom_loader=loader.maybe_get_custom_loader( + checkpoint_path, flags.FLAGS.custom_checkpoint_loader + ), + mask_cache_size=converter.get_mask_cache_size_from_flags(), + ) + + export_config = export_cfg.get_from_flags() + + converter.convert_to_tflite( + pytorch_model, + output_path=flags.FLAGS.output_path, + output_name_prefix=flags.FLAGS.output_name_prefix, + prefill_seq_len=flags.FLAGS.prefill_seq_lens, + kv_cache_max_len=flags.FLAGS.kv_cache_max_len, + quantize=flags.FLAGS.quantize, + lora_ranks=flags.FLAGS.lora_ranks, + export_config=export_config, + ) + + +if __name__ == '__main__': + app.run(main) diff --git a/ai_edge_torch/generative/examples/falcon/falcon.py b/ai_edge_torch/generative/examples/falcon/falcon.py new file mode 100644 index 00000000..5cbbed5f --- /dev/null +++ b/ai_edge_torch/generative/examples/falcon/falcon.py @@ -0,0 +1,106 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+"""Falcon-1B model implementation."""
+
+from typing import Callable, Dict
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import model_builder
+import torch
+from torch import nn
+
+TENSOR_NAMES = loader.ModelLoader.TensorNames(
+    embedding="transformer.word_embeddings",
+    final_norm="transformer.ln_f",
+    pre_attn_norm="transformer.h.{}.input_layernorm",
+    post_attn_norm="transformer.h.{}.post_attention_layernorm",
+    attn_fused_qkv_proj="transformer.h.{}.self_attention.query_key_value",
+    attn_output_proj="transformer.h.{}.self_attention.dense",
+    ff_up_proj="transformer.h.{}.mlp.dense_h_to_4h",
+    ff_down_proj="transformer.h.{}.mlp.dense_4h_to_h",
+    lm_head="lm_head",
+)
+
+
+class Falcon(model_builder.DecoderOnlyModel):
+  """A Falcon-1B model built from the Edge Generative API layers."""
+  pass
+
+
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Falcon-1B model."""
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      head_dim=64,
+      num_query_groups=32,  # Multi-Head Attention
+      use_alibi=True,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,  # Falcon uses a standard MLP
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU),
+      intermediate_size=8192,  # 4 * embedding_dim
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-5, use_bias=True
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      parallel_residual=False,  # parallel_attn=False in config
+  )
+  config = cfg.ModelConfig(
+      vocab_size=50304,
+      num_layers=24,
+      max_seq_len=2048,
+      embedding_dim=2048,
+      block_configs=[block_config] * 24,  # All layers are the same
+      final_norm_config=norm_config,
+  )
+  return config
+
+
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.block_configs[0].ff_config.intermediate_size = 64
+  return config
+
+
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
+    mask_cache_size: int = 0,
+) -> nn.Module:
+  """Builds the Falcon-1B model."""
+  # TODO(adisr): Confirm tensor names for Falcon-1B.
+  # TENSOR_NAMES above follows the Falcon-RW-1B checkpoint layout.
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(),
+      # Uncomment for testing with a smaller model
+      # config=get_fake_model_config(),
+      tensor_names=TENSOR_NAMES,
+      model_class=Falcon,
+      custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
+  )
diff --git a/ai_edge_torch/generative/examples/falcon/verify.py b/ai_edge_torch/generative/examples/falcon/verify.py
new file mode 100644
index 00000000..2bc4cdd8
--- /dev/null
+++ b/ai_edge_torch/generative/examples/falcon/verify.py
@@ -0,0 +1,45 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Verifies the reauthored Falcon-1B model.""" + +from absl import app +from absl import flags +from ai_edge_torch.generative.examples.falcon import verify_util + + +_PROMPTS = flags.DEFINE_multi_string( + "prompts", + "What is the meaning of life?", + "The input prompts to generate answers.", +) +_MAX_NEW_TOKENS = flags.DEFINE_integer( + "max_new_tokens", + 30, + "The maximum size of the generated tokens.", +) +_CHECKPOINT = "tiiuae/falcon-rw-1b" + + +def main(_): + verify_util.verify_falcon_1b( + checkpoint_dir=_CHECKPOINT, + max_new_tokens=_MAX_NEW_TOKENS.value, + prompts=_PROMPTS.value, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/examples/falcon/verify_util.py b/ai_edge_torch/generative/examples/falcon/verify_util.py new file mode 100644 index 00000000..08d69f98 --- /dev/null +++ b/ai_edge_torch/generative/examples/falcon/verify_util.py @@ -0,0 +1,76 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utils for verifying the Falcon-1B model.""" +import logging +import os +import pathlib +from typing import Callable, Dict +from ai_edge_torch.generative.examples.falcon import falcon +from ai_edge_torch.generative.utilities import loader +from ai_edge_torch.generative.utilities import transformers_verifier +from ai_edge_torch.generative.utilities import verifier +import torch +import transformers + +DEFAULT_PROMPTS = ["What is the meaning of life?"] + + +def verify_falcon_1b( + checkpoint_dir: str, + weight_filename: str = "model.safetensors", + max_new_tokens: int = 30, + prompts: list[str] | None = None, + initialize_from_local: bool = True, + custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None, +) -> bool: + """Verifies the reauthored Falcon-1B model with a custom loader.""" + logging.info("Loading the original model from: %s", checkpoint_dir) + original_model = transformers.AutoModelForCausalLM.from_pretrained( + checkpoint_dir + ) + + logging.info("Building the reauthored model from: %s", checkpoint_dir) + + if custom_loader is None and not initialize_from_local: + custom_loader = loader.get_custom_loader("", "safetensors") + + if initialize_from_local: + # Locate the cached dir. 
+ cached_config_file = transformers.utils.cached_file( + checkpoint_dir, transformers.utils.CONFIG_NAME + ) + reauthored_checkpoint = pathlib.Path(cached_config_file).parent + else: + reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename) + + logging.info("Building the reauthored model from: %s", reauthored_checkpoint) + reauthored_model = falcon.build_model( + checkpoint_path=reauthored_checkpoint, + custom_loader=custom_loader, + mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN, + ) + + logging.info("Loading the tokenizer from: %s", checkpoint_dir) + tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir) + return verifier.verify_reauthored_model( + original_model=transformers_verifier.TransformersModelWrapper( + original_model + ), + reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model), + tokenizer=verifier.TokenizerWrapper(tokenizer), + generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts, + max_new_tokens=max_new_tokens, + atol=1e-3, + ) diff --git a/ai_edge_torch/generative/layers/attention.py b/ai_edge_torch/generative/layers/attention.py index efe8de5c..e56ebd85 100644 --- a/ai_edge_torch/generative/layers/attention.py +++ b/ai_edge_torch/generative/layers/attention.py @@ -18,6 +18,7 @@ import abc from typing import Optional, Tuple, Union +from ai_edge_torch.generative.layers import attention_utils from ai_edge_torch.generative.layers import builder from ai_edge_torch.generative.layers import kv_cache as kv_utils from ai_edge_torch.generative.layers import lora as lora_utils @@ -240,13 +241,32 @@ def forward( k = k.reshape(B, T, -1, self.config.head_dim) v = v.reshape(B, T, -1, self.config.head_dim) - if rope is not None: + alibi_bias = None + if self.config.use_alibi: + k_size = T + if mask is not None: + k_size = mask.shape[-1] + elif input_pos is not None: + # If mask is not present, assume current sequence length is key length. + k_size = input_pos[-1].item() + 1 + alibi_bias = attention_utils.build_alibi_bias( + self.config.num_heads, T, k_size, dtype=x.dtype, device=x.device + ) + elif rope is not None: # Compute rotary positional embedding for query and key. cos, sin = rope q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin) sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update( - q, k, v, kv_cache, input_pos, mask, self.config, self.enable_hlfb + q, + k, + v, + kv_cache, + input_pos, + mask, + self.config, + self.enable_hlfb, + alibi_bias=alibi_bias, ) # Compute the output projection. diff --git a/ai_edge_torch/generative/layers/attention_utils.py b/ai_edge_torch/generative/layers/attention_utils.py index b9383d1e..e960d87a 100644 --- a/ai_edge_torch/generative/layers/attention_utils.py +++ b/ai_edge_torch/generative/layers/attention_utils.py @@ -15,11 +15,75 @@ # Common utility functions used with attention module. import math -from typing import Tuple +from typing import List, Tuple import torch +def _get_alibi_slopes(n_heads: int) -> List[float]: + """Returns slopes for ALiBi implementation. + + The slopes are taken from the ALiBi paper + [https://arxiv.org/abs/2108.12409]. + The slopes are later used to calculate the bias which is added to the + attention scores. + + Args: + n_heads (int): The number of attention heads. 
+ """ + + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n_heads).is_integer(): + return get_slopes_power_of_2(n_heads) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n_heads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + _get_alibi_slopes(2 * closest_power_of_2)[0::2][ + : n_heads - closest_power_of_2 + ] + ) + + +def build_alibi_bias( + n_heads: int, + q_size: int, + k_size: int, + dtype: torch.dtype = torch.float32, + device: torch.device = None, +) -> torch.Tensor: + """Builds ALiBi bias tensor based on key position. + + The bias tensor is added to the attention scores before softmax. + Replicates HuggingFace Falcon implementation behavior where bias only depends + on key position j, not relative position j-i. + + Args: + n_heads (int): The number of attention heads. + q_size (int): The query size of the bias tensor. + k_size (int): The key size of the bias tensor. + dtype (torch.dtype, optional): Output tensor's data type. Defaults to + torch.float32. + device (torch.device, optional): Output tensor's data type. Defaults to + None in which case "cpu" is used. + + Returns: + torch.Tensor: The ALiBi bias tensor of shape (1, n_heads, 1, k_size). + """ + if device is None: + device = torch.device('cpu') + slopes = torch.tensor(_get_alibi_slopes(n_heads), dtype=dtype, device=device) + k_pos = torch.arange(k_size, device=device) + # According to HF implementation, bias only depends on key position. + # slopes[h] * k_pos[j] + alibi_bias = slopes.unsqueeze(-1) * k_pos.unsqueeze(0) # Shape: H, K + return alibi_bias[None, :, None, :].to(dtype) + + def build_rope_cache( size: int, dim: int, diff --git a/ai_edge_torch/generative/layers/builder.py b/ai_edge_torch/generative/layers/builder.py index ceaf6d92..bc889229 100644 --- a/ai_edge_torch/generative/layers/builder.py +++ b/ai_edge_torch/generative/layers/builder.py @@ -71,7 +71,7 @@ def build_norm( Raises: ValueError: If config's `layer_norm_type` is not supported. """ - if config.type == cfg.NormalizationType.NONE: + if config is None or config.type == cfg.NormalizationType.NONE: return lambda x: x elif config.type == cfg.NormalizationType.RMS_NORM: return normalization.RMSNorm( @@ -84,7 +84,9 @@ def build_norm( init_fn=init_fn, ) elif config.type == cfg.NormalizationType.LAYER_NORM: - return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb) + return normalization.LayerNorm( + dim, config.epsilon, config.use_bias, config.enable_hlfb + ) elif config.type == cfg.NormalizationType.GROUP_NORM: return normalization.GroupNorm( config.group_num, dim, config.epsilon, config.enable_hlfb diff --git a/ai_edge_torch/generative/layers/model_config.py b/ai_edge_torch/generative/layers/model_config.py index 46a67ae6..cccd8f2f 100644 --- a/ai_edge_torch/generative/layers/model_config.py +++ b/ai_edge_torch/generative/layers/model_config.py @@ -75,6 +75,8 @@ class NormalizationConfig: scale_shift: float = 0.0 # Number of groups used in group normalization. group_num: Optional[float] = None + # Whether to use bias in norm. + use_bias: bool = True # Exprimental feature and may subject to change. @@ -108,6 +110,8 @@ class AttentionConfig: rotary_base: int = 10_000 # Percentage of Rotary Positional Embedding added Q and K projections. rotary_percentage: Optional[float] = None + # Whether to use ALiBi positional encoding. 
+ use_alibi: bool = False # Whether to transpose the query groups of qkv bundled tensor before # splitting into separated tensors. qkv_transpose_before_split: bool = False diff --git a/ai_edge_torch/generative/layers/normalization.py b/ai_edge_torch/generative/layers/normalization.py index dfaa6e55..12f17c68 100644 --- a/ai_edge_torch/generative/layers/normalization.py +++ b/ai_edge_torch/generative/layers/normalization.py @@ -148,6 +148,7 @@ def __init__( self, dim: int, eps: float = 1e-5, + use_bias: bool = True, enable_hlfb: bool = False, ): """Initialize the LayerNorm layer. @@ -156,6 +157,7 @@ def __init__( dim (int): dimension of the input tensor. eps (float): A small float value to ensure numerical stability (default: 1e-5). + use_bias (bool): Whether to use bias in LayerNorm. enable_hlfb (bool): Whether to convert this normalization into a single op. """ @@ -164,7 +166,11 @@ def __init__( self.normalized_shape = (dim,) self.eps = eps self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False) - self.bias = torch.nn.Parameter(torch.empty(dim), requires_grad=False) + self.bias = ( + torch.nn.Parameter(torch.empty(dim), requires_grad=False) + if use_bias + else None + ) def forward(self, x): """Running the forward pass of LayerNorm layer. @@ -175,7 +181,7 @@ def forward(self, x): Returns: torch.Tensor: output tensor after applying LayerNorm. """ - if self.enable_hlfb: + if self.enable_hlfb and self.bias is not None: return layer_norm_with_hlfb( x, self.normalized_shape, self.weight, self.bias, self.eps ) diff --git a/ai_edge_torch/generative/layers/scaled_dot_product_attention.py b/ai_edge_torch/generative/layers/scaled_dot_product_attention.py index 874aad52..424c3c45 100644 --- a/ai_edge_torch/generative/layers/scaled_dot_product_attention.py +++ b/ai_edge_torch/generative/layers/scaled_dot_product_attention.py @@ -32,6 +32,7 @@ def scaled_dot_product_attention( mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, softcap: Optional[float] = None, + alibi_bias: Optional[torch.Tensor] = None, ): """Scaled dot product attention. @@ -41,14 +42,23 @@ def scaled_dot_product_attention( v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H]. head_size (int): head dimension. mask (torch.Tensor): the optional mask tensor. + scale (float): the optional scale factor. + softcap (float): the optional softcap for the logits. + alibi_bias (torch.Tensor): optional alibi bias tensor. Returns: The output tensor of scaled_dot_product_attention. """ - if scale is None: scale = 1.0 / math.sqrt(head_size) + if alibi_bias is not None: + alibi_bias = alibi_bias * scale + if mask is None: + mask = alibi_bias + else: + mask = mask + alibi_bias + q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) @@ -72,7 +82,8 @@ def scaled_dot_product_attention( scores = scores / softcap scores = torch.tanh(scores) scores = scores * softcap - scores = scores + mask + if mask is not None: + scores = scores + mask out = F.softmax(scores.float(), dim=-1).type_as(q) y = torch.matmul(out, v) @@ -87,6 +98,7 @@ def scaled_dot_product_attention_with_hlfb( mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, softcap: Optional[float] = None, + alibi_bias: Optional[torch.Tensor] = None, ): """Scaled dot product attention with high-level function boundary enabled. @@ -96,14 +108,23 @@ def scaled_dot_product_attention_with_hlfb( v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H]. head_size (int): head dimension. mask (torch.Tensor): the optional mask tensor. 
+ scale (float): the optional scale factor. + softcap (float): the optional softcap for the logits. + alibi_bias (torch.Tensor): optional alibi bias tensor. Returns: The output tensor of scaled_dot_product_attention. """ - if scale is None: scale = 1.0 / math.sqrt(head_size) + if alibi_bias is not None: + alibi_bias = alibi_bias * scale + if mask is None: + mask = alibi_bias + else: + mask = mask + alibi_bias + attrs = {"scale": scale} if softcap is not None: @@ -137,7 +158,8 @@ def scaled_dot_product_attention_with_hlfb( scores = scores / softcap scores = torch.tanh(scores) scores = scores * softcap - scores = scores + mask + if mask is not None: + scores = scores + mask out = F.softmax(scores.float(), dim=-1).type_as(q) y = torch.matmul(out, v) @@ -154,6 +176,7 @@ def scaled_dot_product_attention_transposed( mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, softcap: Optional[float] = None, + alibi_bias: Optional[torch.Tensor] = None, ): """Scaled dot product attention with transposed key and value. @@ -165,14 +188,21 @@ def scaled_dot_product_attention_transposed( mask (torch.Tensor): the optional mask tensor. scale (float): the optional scale factor. softcap (float): the optional softcap for the logits. + alibi_bias (torch.Tensor): optional alibi bias tensor. Returns: The output tensor of scaled_dot_product_attention_transposed. """ - if scale is None: scale = 1.0 / math.sqrt(head_size) + if alibi_bias is not None: + alibi_bias = alibi_bias * scale + if mask is None: + mask = alibi_bias + else: + mask = mask + alibi_bias + query = query * scale assert mask is not None, "Mask should not be None!" diff --git a/ai_edge_torch/generative/layers/sdpa_with_kv_update.py b/ai_edge_torch/generative/layers/sdpa_with_kv_update.py index 38b5683f..c8dd5806 100644 --- a/ai_edge_torch/generative/layers/sdpa_with_kv_update.py +++ b/ai_edge_torch/generative/layers/sdpa_with_kv_update.py @@ -15,7 +15,7 @@ """Common utility functions for data loading etc.""" -from typing import Tuple +from typing import Optional, Tuple from ai_edge_torch.generative.layers import kv_cache as kv_utils from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa @@ -32,14 +32,15 @@ def sdpa_with_kv_update( mask: torch.Tensor, config: cfg.AttentionConfig, enable_hlfb: bool, + alibi_bias: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]: """Wrapper function for scaled dot product attention with KV cache update.""" if kv is not None and kv.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED: return _sdpa_with_kv_update_transposed( - query, key, value, kv, input_pos, mask, config + query, key, value, kv, input_pos, mask, config, alibi_bias ) return _sdpa_with_kv_update_default( - query, key, value, kv, input_pos, mask, config, enable_hlfb + query, key, value, kv, input_pos, mask, config, enable_hlfb, alibi_bias ) @@ -51,6 +52,7 @@ def _sdpa_with_kv_update_transposed( input_pos: torch.Tensor, mask: torch.Tensor, config: cfg.AttentionConfig, + alibi_bias: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]: # Transpose k/v to specific layout for GPU implementation. 
b, seq_len, n, h = query.shape @@ -77,6 +79,7 @@ def _sdpa_with_kv_update_transposed( config.head_dim, mask=mask, softcap=config.logit_softcap, + alibi_bias=alibi_bias, ) # 1, bk, gt, h sdpa_out = ( sdpa_out.reshape(b, -1, seq_len, h) @@ -95,6 +98,7 @@ def _sdpa_with_kv_update_default( mask: torch.Tensor, config: cfg.AttentionConfig, enable_hlfb: bool, + alibi_bias: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]: b, seq_len, _, _ = query.shape if kv is not None: @@ -112,6 +116,7 @@ def _sdpa_with_kv_update_default( config.head_dim, mask=mask, softcap=config.logit_softcap, + alibi_bias=alibi_bias, ) sdpa_out = sdpa_out.reshape(b, seq_len, -1) return sdpa_out, kv diff --git a/ai_edge_torch/generative/utilities/loader.py b/ai_edge_torch/generative/utilities/loader.py index 63c177d8..9d66d006 100644 --- a/ai_edge_torch/generative/utilities/loader.py +++ b/ai_edge_torch/generative/utilities/loader.py @@ -135,7 +135,8 @@ def load_pytorch_statedict(full_path: str): tensors = {} for file in files: - this_file_tensors = torch.load(file) + map_location = "cpu" if not torch.cuda.is_available() else None + this_file_tensors = torch.load(file, map_location=map_location) for k in this_file_tensors: assert k not in tensors tensors.update(this_file_tensors)
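
Usage sketch: the snippet below illustrates how the new build_alibi_bias helper is expected to combine with an additive causal mask, mirroring the `mask = mask + alibi_bias * scale` logic added to the SDPA helpers in this change. It is a minimal illustration that assumes the helpers land exactly as written above; it is not part of the patch itself. Shapes follow the attention_utils.py docstrings.

import math

import torch
from ai_edge_torch.generative.layers import attention_utils

n_heads, head_dim, q_len, k_len = 32, 64, 4, 8
scale = 1.0 / math.sqrt(head_dim)

# (1, n_heads, 1, k_len): one ALiBi slope per head; bias grows with key position.
alibi_bias = attention_utils.build_alibi_bias(n_heads, q_len, k_len)

# Standard additive causal mask: 0 where attention is allowed, a large negative
# value where it is not (queries cover the last q_len of the k_len positions).
neg_inf = torch.finfo(torch.float32).min
causal_mask = torch.triu(
    torch.full((q_len, k_len), neg_inf), diagonal=k_len - q_len + 1
)[None, None, :, :]

# As in the updated scaled_dot_product_attention: pre-scale the bias, then fold
# it into the mask so a single additive term is applied to the logits.
combined_mask = causal_mask + alibi_bias * scale
print(combined_mask.shape)  # torch.Size([1, 32, 4, 8])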
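
Similarly, a minimal sketch of running the reauthored EmbeddingGemma end to end, mirroring what verify_util.py does before the allclose comparison. The checkpoint path is hypothetical; it assumes a local directory laid out as embedding_gemma.build_model expects (top-level safetensors plus 2_Dense/3_Dense sub-directories) with the Hugging Face tokenizer files alongside.

import torch
import transformers
from ai_edge_torch.generative.examples.embedding_gemma import embedding_gemma

CHECKPOINT_DIR = "/path/to/embeddinggemma-300m"  # hypothetical local path

tokenizer = transformers.AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
model = embedding_gemma.build_model(CHECKPOINT_DIR)
model.eval()

inputs = tokenizer(
    ["What is the meaning of life?"], return_tensors="pt", padding=True
)
with torch.no_grad():
  # The forward pass includes mean pooling, the two dense layers, and L2
  # normalization, so the output is the final embedding of shape
  # (batch, embedding_dim=768).
  embeddings = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(embeddings.shape)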