Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Callable
from ai_edge_torch import fx_infra
from ai_edge_torch import lowertools
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
import torch
import torch.utils._pytree as pytree

Expand Down Expand Up @@ -276,6 +277,7 @@ def embedding(*args, **kwargs):
# Explicitly reshape back to the original shape. This places the ReshapeOp
# outside of the HLFB.
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
return output

node.target = embedding
Expand Down
51 changes: 43 additions & 8 deletions ai_edge_torch/generative/examples/gemma/gemma1.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch
from torch import nn

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames(
ff_up_proj="model.layers.{}.mlp.up_proj",
ff_down_proj="model.layers.{}.mlp.down_proj",
ff_gate_proj="model.layers.{}.mlp.gate_proj",
Expand All @@ -36,6 +36,24 @@
lm_head=None,
)

# Tensor-name layout for checkpoints that store separate q/k/v projection
# weights (one entry each for q_proj / k_proj / v_proj).
TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    post_attn_norm="model.layers.{}.post_attention_layernorm",
    embedding="model.embed_tokens",
    final_norm="model.norm",
)

# Maps a checkpoint source label to the tensor-name layout it uses.
# NOTE(review): presumably "safetensors" checkpoints ship separate q/k/v
# weights while "kaggle" checkpoints ship a fused qkv — confirm against the
# checkpoints these labels refer to.
TENSOR_NAMES_DICT = {
    "safetensors": TENSOR_NAMES_SEP_QKV,
    "kaggle": TENSOR_NAMES_FUSED_QKV,
}

class Gemma1(model_builder.DecoderOnlyModel):
"""A Gemma1 model built from the Edge Generative API layers."""
Expand Down Expand Up @@ -94,11 +112,28 @@ def build_2b_model(
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
mask_cache_size: int = 0,
) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config_2b(),
tensor_names=TENSOR_NAMES,
model_class=Gemma1,
custom_loader=custom_loader,
mask_cache_size=mask_cache_size,

# A list to store the reasons for each failure
key_errors = []

for tensor_names in TENSOR_NAMES_DICT.values():
try:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config_2b(),
tensor_names=tensor_names,
model_class=Gemma1,
custom_loader=custom_loader,
mask_cache_size=mask_cache_size,
)
except KeyError as ke:
# Store the specific key that was missing for later
key_errors.append(f"Missing key: {ke}")
continue

# If the loop finishes, raise an error with all the collected details
error_details = "\n".join(key_errors)
raise RuntimeError(
"Failed to build model after trying all configurations. "
f"Encountered the following errors:\n{error_details}"
)
95 changes: 95 additions & 0 deletions ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Verifies the reauthored SmolVLM2 Image Encoder model."""

import logging

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.smolvlm2 import smolvlm2
from ai_edge_torch.generative.examples.smolvlm2 import vision_encoder
from PIL import Image
import requests
import torch
import transformers

# URL of the test image fed to both the original and reauthored encoders.
_IMAGE_URL = flags.DEFINE_string(
    "image_url",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
    "The image URI to encode.",
)

# HuggingFace checkpoint id used to load the original (reference) model.
_CHECKPOINT = flags.DEFINE_string(
    "checkpoint",
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    "The checkpoint to verify.",
)

# Path to the reauthored model's weights; required at runtime (main() raises
# ValueError when unset).
# NOTE(review): "_REAUTHORTHED_" looks like a typo for "_REAUTHORED_"; renaming
# would touch every use, so only flagging it here.
_REAUTHORTHED_CHECKPOINT = flags.DEFINE_string(
    "pretrained_weights_path",
    None,
    "The path to the model's pretrained weights.",
)


def main(_):
  """Compares the original and reauthored SmolVLM2 image encoders on one image.

  Loads both models, encodes the flag-selected image with each, and asserts
  the outputs match within tolerance. Raises ValueError when the reauthored
  weights path flag is unset, AssertionError when the outputs diverge.
  """
  ckpt = _CHECKPOINT.value
  logging.info("Loading the original model from: %s", ckpt)
  ref_model = transformers.AutoModelForImageTextToText.from_pretrained(ckpt)
  ref_model = ref_model.eval().model

  logging.info("Building the reauthored checkpoint from: %s", ckpt)
  weights_path = _REAUTHORTHED_CHECKPOINT.value
  if weights_path is None:
    raise ValueError("reauthored_checkpoint is required.")

  logging.info("Building the reauthored model from: %s", weights_path)
  edge_model = vision_encoder.build_image_encoder(weights_path)

  logging.info("Loading the tokenizer from: %s", ckpt)
  processor = transformers.AutoProcessor.from_pretrained(ckpt)

  logging.info("Loading the image from: %s", _IMAGE_URL.value)
  raw_stream = requests.get(_IMAGE_URL.value, stream=True).raw
  test_image = Image.open(raw_stream)
  pixels = processor(images=test_image, return_tensors="pt")["pixel_values"]

  logging.info("Forwarding the original model...")
  ref_out = ref_model.get_image_features(pixels)
  logging.info("outputs_original's shape: %s", ref_out.shape)

  # Fold the leading (batch, frames) dims into one before the reauthored pass.
  lead = pixels.shape[0] * pixels.shape[1]
  pixels = pixels.reshape(lead, *pixels.shape[2:])
  logging.info("Forwarding the reauthored model...")
  edge_out = edge_model.forward(pixel_values=pixels)
  logging.info("outputs_reauthored's shape: %s", edge_out.shape)

  try:
    assert torch.allclose(ref_out, edge_out, atol=1e-03, rtol=1e-04)
  except AssertionError as e:
    logging.error("*** FAILED *** verify with an image")
    raise e
  else:
    logging.info("*** PASSED *** verify with an image")


if __name__ == "__main__":
app.run(main)
13 changes: 12 additions & 1 deletion ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,18 @@ def forward(
pixel_values: torch.Tensor,
export_config: export_cfg.ExportConfig = None,
) -> torch.Tensor:
x = self.siglip_encoder(pixel_values)
# Embed the image according to SiplipVisionEmbeddings.
x = self.siglip_encoder.tok_embedding(pixel_values)
x = x.flatten(2).transpose(1, 2)
x = x + self.siglip_encoder.tok_embedding_position

# Pass a dummy mask because SDPA attention impl expects non-None mask.
mask = torch.zeros(x.shape[0], 1, x.shape[1], x.shape[1])
for _, block in enumerate(self.siglip_encoder.transformer_blocks):
x = block(x, mask=mask)
x = self.siglip_encoder.final_norm(x)

# Project the image embeddings to text hidden size.
x = self.connector(x)
return x

Expand Down
61 changes: 61 additions & 0 deletions ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ def _aten_rsqrt_decomp(x):
return torch.ops.tfl.rsqrt(x)


@register_decomp(torch.ops.aten.neg.default)
def _aten_neg_decomp(x):
  """Decomposes aten.neg into the tfl.neg custom op."""
  return torch.ops.tfl.neg(x)


@register_decomp(torch.ops.aten.gelu.default)
def _aten_gelu_decomp(x, approximate="none"):
return torch.ops.tfl.gelu(x, approximate != "none")
Expand Down Expand Up @@ -317,6 +322,38 @@ def _aten_select_int_decomp(x, dim, index):
return torch.ops.tfl.squeeze(sliced, [dim])


@register_decomp(torch.ops.aten.slice.Tensor)
def _aten_slice_tensor_decomp(x, dim=0, start=None, end=None, step=1):
  """Decomposes aten.slice.Tensor into tfl.strided_slice.

  Matches PyTorch's `slice` semantics: `start`/`end` may be negative
  (counted from the end of the dimension), `start=None` defaults to 0, and
  `end=None` or an out-of-range value is clamped to the dimension size.
  """
  axis_len = x.shape[dim]

  # Resolve the start index: default 0, wrap negatives, clamp to [0, axis_len].
  lo = 0 if start is None else start
  if lo < 0:
    lo += axis_len
  lo = min(max(lo, 0), axis_len)

  # Resolve the end index: default axis_len, wrap negatives, clamp to
  # [lo, axis_len] so the slice is never reversed.
  hi = axis_len if end is None else end
  if hi < 0:
    hi += axis_len
  hi = min(max(hi, lo), axis_len)

  # Full-extent begin/end/strides on every axis except the sliced one.
  begin = [0] * x.dim()
  limit = list(x.shape)
  strides = [1] * x.dim()
  begin[dim] = lo
  limit[dim] = hi
  strides[dim] = step

  return torch.ops.tfl.strided_slice(x, begin, limit, strides)


@register_decomp(torch.ops.aten.where.self)
def _aten_where_self_decomp(condition, x, y):
x, y = _promote_types_for_binary_op(x, y)
Expand Down Expand Up @@ -351,3 +388,27 @@ def _aten__softmax_decomp(
softmax_result = torch.ops.tfl.softmax(x_permuted)
# Transpose the result back to the original dimensions.
return torch.ops.tfl.transpose(softmax_result, dims)


@register_decomp(torch.ops.aten.topk.default)
def _aten_topk_decomp(self, k, dim=-1, largest=True, sorted=True):
  """Decomposes aten.topk into tfl.topk_v2.

  tfl.topk_v2 only operates on the last dimension, so for other dims the
  input is transposed so `dim` becomes last, and the results are transposed
  back. Raises ValueError for largest=False, which tfl.topk_v2 cannot express.
  """
  if not largest:
    raise ValueError("Only largest=True is supported for torch.topk.")

  # Normalize a negative dim to its positive index.
  if dim < 0:
    dim = self.dim() + dim

  # Move the target dim to the end if it isn't there already.
  if dim != self.dim() - 1:
    self = torch.transpose(self, dim, -1)

  # Ignores sorted value: tfl.topk_v2 only supports sorted=True, but it doesn't
  # affect the correctness of the output.
  out, indices = torch.ops.tfl.topk_v2(self, k)

  # Undo the transpose so the outputs line up with the caller's dim.
  if dim != self.dim() - 1:
    out = torch.transpose(out, dim, -1)
    indices = torch.transpose(indices, dim, -1)

  # torch.topk returns int64 indices, but tfl.topk_v2 returns indices in int32.
  indices = indices.to(torch.int64)
  return out, indices
29 changes: 29 additions & 0 deletions ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,18 @@ def _tfl_rsqrt_lowering(
)


@lower(torch.ops.tfl.neg.default)
def _tfl_neg_lowering(
    lctx: LoweringContext,
    x: ir.Value,
) -> ir.Value:
  """Lowers the tfl.neg custom op to a `tfl.neg` IR operation."""
  return _ir_operation(
      "tfl.neg",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[x],
  )


@lower(torch.ops.tfl.gelu.default)
def _tfl_gelu_lowering(
lctx: LoweringContext,
Expand Down Expand Up @@ -674,3 +686,20 @@ def _tfl_softmax_lowering(
"beta": ir.FloatAttr.get(ir.F32Type.get(), beta),
},
)


@lower(torch.ops.tfl.topk_v2.default)
def _tfl_topk_v2_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    k: int,
) -> tuple[ir.Value, ir.Value]:
  """Lowers the tfl.topk_v2 custom op to a `tfl.topk_v2` IR operation.

  `k` is materialized as a scalar int32 constant operand; the result types
  (values, indices) are taken from the node's metadata.
  """
  return _ir_operation(
      "tfl.topk_v2",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[
          x,
          lowering_utils.numpy_array_constant(np.array(k, dtype=np.int32)),
      ],
      attributes={},
  )
12 changes: 12 additions & 0 deletions ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ def tfl_rsqrt(x: torch.Tensor) -> torch.Tensor:
return torch.rsqrt(x)


@custom_op_with_fake("tfl::neg")
def tfl_neg(x: torch.Tensor) -> torch.Tensor:
  """Reference (eager) implementation of the tfl.neg custom op."""
  return torch.neg(x)


@custom_op_with_fake("tfl::gelu")
def tfl_gelu(x: torch.Tensor, approximate: bool = False) -> torch.Tensor:
gelu_approximate = "tanh" if approximate else "none"
Expand Down Expand Up @@ -292,6 +297,13 @@ def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.softmax(x, dim=-1)


@custom_op_with_fake("tfl::topk_v2")
def tfl_topk_v2(x: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
  """Reference (eager) implementation of the tfl.topk_v2 custom op.

  Returns the top-k values and their indices along the last dimension.
  Indices are cast to int32 to match TFLite's topk_v2 output, unlike
  torch.topk's int64 indices.
  """
  out, indices = torch.topk(x, k, dim=-1, largest=True, sorted=True)
  indices = indices.to(torch.int32)
  return out, indices


@custom_op_with_fake(
"tfl::slice", schema="(Tensor x, SymInt[] begin, SymInt[] size) -> Tensor"
)
Expand Down
Loading
Loading