
Commit ae137e7

haozha111 authored and copybara-github committed
Update documentation to reflect use of odml torch.
PiperOrigin-RevId: 706031730
1 parent b19fa25 commit ae137e7

File tree

1 file changed, +3 −3 lines changed


ai_edge_torch/generative/doc/system_overview.md

Lines changed: 3 additions & 3 deletions
@@ -47,7 +47,7 @@ The user journey of the Edge Generative API begins from the authoring library. T
During our initial investigation, we found that although there are many open-source PyTorch implementations of common LLM models (e.g. Llama), it's quite difficult to convert those models from the source program to the TF Lite format and achieve good performance. Those difficulties include:

1) **Not able to be exported by Torch Dynamo**. The PyTorch 2.0 compiler uses Torch Dynamo under the hood to perform graph capturing (tracing a PyTorch `nn.Module` and converting it to a full graph representation with Aten operations). For an arbitrary PyTorch model, Dynamo may fail during full-graph export. Rewriting the model to be Dynamo-exportable usually takes non-trivial work and may require a deep understanding of the compiler stack (see the sketch after this list).
-2) **Not able to be converted by AI edge torch**. The AI edge torch converter utilizes TorchXLA, StableHLO and the TF Lite converter to convert the Aten graph to the TF Lite model format. During this process, conversion may fail if the source FX graph has certain graph patterns or ops that are not supported by the converter. Because the conversion goes through multiple stages, it's very difficult for most users with little compiler knowledge to triage and fix conversion issues (whether in the model itself or in the converter stack).
+2) **Not able to be converted by AI edge torch**. The AI edge torch converter utilizes ODML Torch, StableHLO and the TF Lite converter to convert the Aten graph to the TF Lite model format. During this process, conversion may fail if the source FX graph has certain graph patterns or ops that are not supported by the converter. Because the conversion goes through multiple stages, it's very difficult for most users with little compiler knowledge to triage and fix conversion issues (whether in the model itself or in the converter stack).
3) **Difficult to achieve good OOB performance**. Even if you get lucky and successfully make it through to the TF Lite model format, there is no guarantee that the model will run performantly on device and be able to leverage on-device accelerators such as XNNPack, GPU or NPUs. For example, performance-critical layers such as KV cache and SDPA need to be handled specifically in the converter and runtime to ensure they can be accelerated.
4) **Applying quantization consistently is difficult**. It's difficult to specify quantization recipes in a consistent way for an arbitrary PyTorch generative AI model. For ODML, we need to carefully co-design our custom building blocks and the AI Edge quantizer so that they work smoothly together and are easy to configure with custom quantization recipes without a noticeable drop in model quality.
5) **Many OSS models are not designed / optimized for mobile execution**. For example, those implementations may contain distributed training logic, CUDA dependencies, and tensor parallel optimizations for fast training. Models usually need to be rewritten to remove those pieces before exporting to mobile.
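To make point 1 concrete, here is a minimal, illustrative sketch of full-graph capture with `torch.export` (which uses Dynamo under the hood). The toy module and its data-dependent branch are hypothetical examples, not code from this repository:

```python
import torch


class DataDependentBranch(torch.nn.Module):
  """A hypothetical module with data-dependent Python control flow."""

  def forward(self, x):
    # Branching on a tensor value is a typical reason why Dynamo cannot
    # capture a single full graph for an arbitrary model.
    if x.sum() > 0:
      return x * 2
    return x - 1


# Full-graph capture via torch.export (Dynamo under the hood). This raises an
# export error for the module above, while a branch-free rewrite (e.g. with
# torch.where) exports cleanly to an Aten-level FX graph.
try:
  exported = torch.export.export(DataDependentBranch().eval(), (torch.randn(4),))
  print(exported.graph_module.graph)
except Exception as e:
  print(f"Export failed: {type(e).__name__}: {e}")
```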
@@ -83,7 +83,7 @@ TODO: fill in this part.

The AI edge torch Generative API adopts the same PyTorch to TF Lite conversion flow as traditional models. On a high level, it involves the following steps:
1) PyTorch compiler (Dynamo). We leverage [Torch Dynamo](https://pytorch.org/docs/stable/torch.compiler_deepdive.html), which is the official graph compiler for PyTorch 2.0. It performs graph capturing and exports the PyTorch `nn.Module` to an FX graph with Aten operations.
-2) TorchXLA. For the moment, we are using [Torch/XLA](https://github.com/pytorch/xla) to compile the FX graph to a StableHLO graph. The converted graph will be stored as a TensorFlow SavedModel.
+2) ODML torch. For the moment, we are using [ODML torch](https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/odml_torch) to compile the FX graph to a StableHLO graph. The converted graph will be stored as a TensorFlow SavedModel.
3) TF Lite MLIR converter. The TF Lite converter consumes the SavedModel with StableHLO ops and further lowers it down to TF Lite ops.

Converting a Generative AI model with several billion parameters usually takes a lot of CPU and RAM resources on your machine, as the converter performs all kinds of graph optimizations and transformations to ensure the converted model is highly optimized. Going forward, we will continue to improve the converter infrastructure to reduce its system overhead.
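From a user's perspective, all three steps above run behind a single convert call. A minimal sketch based on the project's top-level `ai_edge_torch.convert` API; the torchvision model and output path are placeholders for illustration:

```python
import torch
import torchvision

import ai_edge_torch

# Any Dynamo-exportable nn.Module works here; resnet18 just keeps the
# example small and self-contained.
model = torchvision.models.resnet18(weights=None).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

# Dynamo export, ODML torch / StableHLO lowering and the TF Lite MLIR
# converter all run behind this single call.
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model.export("/tmp/resnet18.tflite")
```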
@@ -100,7 +100,7 @@ For code examples, please refer to [this](https://github.com/google-ai-edge/ai-e

### Composite op lowering via high-level function boundary

-To address the performance issues of LLMs, we identified that the model's forward pass is usually bottlenecked by a few key operations, such as KV cache or scaled dot product attention. As the converter traces the whole model and performs graph lowering through Aten/CoreAten/StableHLO etc., certain op groups may not be properly fused together, resulting in bad runtime performance. For example, we leverage PyTorch's [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) to implement the attention computation; however, it will be converted to a less efficient form when converted directly (which leads to unnecessary FLOPs). The core problem is that the converter doesn't know the boundary of an `SDPA` op, so it can't apply any optimizations to it. We use TorchXLA's high-level function boundary API to mark the boundary of the `SDPA` operation, and then lower it to a TF Lite custom op. For example, see how HLFB is applied inside the `SDPA` python implementation:
+To address the performance issues of LLMs, we identified that the model's forward pass is usually bottlenecked by a few key operations, such as KV cache or scaled dot product attention. As the converter traces the whole model and performs graph lowering through Aten/CoreAten/StableHLO etc., certain op groups may not be properly fused together, resulting in bad runtime performance. For example, we leverage PyTorch's [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) to implement the attention computation; however, it will be converted to a less efficient form when converted directly (which leads to unnecessary FLOPs). The core problem is that the converter doesn't know the boundary of an `SDPA` op, so it can't apply any optimizations to it. We use ODML torch's high-level function boundary API to mark the boundary of the `SDPA` operation, and then lower it to a TF Lite custom op. For example, see how HLFB is applied inside the `SDPA` python implementation:
https://github.com/google-ai-edge/ai-edge-torch/blob/9b06abe7c645f1d784804eb7f63f93458f358ba1/ai_edge_torch/generative/layers/scaled_dot_product_attention.py#L69-L117

As a result, everything in between the `builder.mark_inputs` and `builder.mark_outputs` calls will be wrapped inside a StableHLO composite op named `odml.scaled_dot_product_attention` and mapped to TF Lite's [SDPA custom op](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/genai/sdpa.cc) by the converter. During runtime, the delegate will match and replace `odml.scaled_dot_product_attention` with the most efficient kernels available on that delegate backend.
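For reference, a minimal sketch of how such a boundary is marked. The `StableHLOCompositeBuilder` import path and constructor signature are assumptions inferred from the linked SDPA implementation; only `mark_inputs` / `mark_outputs` and the composite name are taken from the text above:

```python
import torch
import torch.nn.functional as F

# Assumed import path for the HLFB builder; see the linked SDPA source for
# the exact API used in this repository.
from ai_edge_torch.hlfb import StableHLOCompositeBuilder


def sdpa_with_hlfb(q, k, v, mask):
  builder = StableHLOCompositeBuilder(name="odml.scaled_dot_product_attention")
  # Everything between mark_inputs and mark_outputs is wrapped into one
  # StableHLO composite op, which the converter maps to the TF Lite SDPA
  # custom op.
  q, k, v, mask = builder.mark_inputs(q, k, v, mask)
  out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
  out = builder.mark_outputs(out)
  return out
```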
