You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: ai_edge_torch/generative/doc/system_overview.md
+3-3Lines changed: 3 additions & 3 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -47,7 +47,7 @@ The user journey of the Edge Generative API begins from the authoring library. T
47
47
During our initial investigation, we found that although there are many open-source PyTorch implementation of common LLM models (e.g. Llama), it's pretty difficult to convert those models from source program to the TF Lite format, and achieve good performance. Those difficulties might include:
48
48
49
49
1)**Not able to be exported by Torch Dynamo**. The PyTorch 2.0 compiler uses Torch Dynamo under-the-hood to perform graph capturing (tracing a PyTorch `nn.Module` and convert it to a full graph representation with Aten operations). For an arbitrary PyTorch model, dynamo may fail during full graph export. Re-writing the model to be dynamo exportable usually takes some non-trivial work and may need a deep understanding of the compiler stack.
50
-
2)**Not able to be converted by AI edge torch**. AI edge torch converter utilizes TorchXLA, StableHLO and TF Lite converter to convert the Aten graph to TF Lite model format. During this process, it may fail if the source FX graph has certain graph patterns / ops that are not supported by the converter. As the conversion will go through multiple stages, it's very difficult to triage and fix the conversion issue (either fixing the model itself or converter stack) for most users with little compiler knowledge.
50
+
2)**Not able to be converted by AI edge torch**. AI edge torch converter utilizes ODML Torch, StableHLO and TF Lite converter to convert the Aten graph to TF Lite model format. During this process, it may fail if the source FX graph has certain graph patterns / ops that are not supported by the converter. As the conversion will go through multiple stages, it's very difficult to triage and fix the conversion issue (either fixing the model itself or converter stack) for most users with little compiler knowledge.
51
51
3)**Difficult to achieve good OOB performance**. Even you may get lucky and successfully make it through the TF Lite model format, there is also no guarantee that the model will run performantly on device, and be able to leverage the on-device accelerators such as XNNPack, GPU or NPUs. For example, performance critical layers such as KV Cache and SDPA need to be handlded specifically in the converter and runtime to ensure it can be accelerated.
52
52
4)**Applying quantization consistently is difficult**. It's difficult to specify quantization receipes in a consistent way for an arbitrary PyTorch generative AI model. For ODML, we need to carefully co-design our custom building blocks and AI Edge quantizer to ensure they work smoothly together, and is easy to configure with custom quantization receipes without noticiable model quality drop.
53
53
5)**Many OSS models are not designed / optimized for mobile execution**. For example, those implementations may contain distributed training logic, CUDA dependencies, and tensor parallel optimizations for fast training. Models usually need to be rewritten to remove those pieces before exporting to mobile.
@@ -83,7 +83,7 @@ TODO: fill in this part.
83
83
84
84
With AI edge torch generative API, it adopts the same PyTorch to TF Lite conversion flow as traditional models. On a high level, it involves the following steps:
85
85
1) PyTorch compiler (Dynamo). We leverage [Torch Dynamo](https://pytorch.org/docs/stable/torch.compiler_deepdive.html) which is the official graph compiler for PyTorch 2.0. It performs graph capturing and exports the PyTorch `nn.Module` to an FX graph with Aten operations.
86
-
2)TorchXLA. For the moment, we are using [Torch/XLA](https://github.com/pytorch/xla) to compile the FX graph to Stable HLO graph. The converted graph will be stored as a Tensorflow SavedModel.
86
+
2)ODML torch. For the moment, we are using [ODML torch](https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/odml_torch) to compile the FX graph to Stable HLO graph. The converted graph will be stored as a Tensorflow SavedModel.
87
87
3) TF Lite MLIR converter. The TF Lite converter will consume the SavedModel with StableHLO ops, and further lower it down to TF Lite ops.
88
88
89
89
To convert a several billion parameters Generative AI model, it usually takes a lot of CPU and RAM resources on your machine, as the converter does all kinds of graph optimizations and transformations to ensure the converted model is highly optimized. Going forward, we will continue to improve the converter infrastructure to reduce its system overhead.
@@ -100,7 +100,7 @@ For code examples, please refer to [this](https://github.com/google-ai-edge/ai-e
100
100
101
101
### Composite op lowering via high-level function boundary
102
102
103
-
To address the performance issues of LLM, we identified that the model's forward pass is usually bottlenecked by a few key operations, such as KV cache or scaled dot product attention. As the converter traces the whole model and performs graph lowering through Aten/CoreAten/StableHLO etc, there are chances that certain op groups are not properly fused together, resulting in bad runtime performance. For example, we leverage PyTorch's [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) to implement attention computation, however it will be converted to less efficient form when converting directly (which leads to unncessary FLOPs). The core problem here is that the converter doesn't know the boundary of a `SDPA` op so it can't apply any optimizations on it. We use TorchXLA's high-level function boundary API to mark the boundary of the `SDPA` operation, and then lower it to TF Lite custom op. For example, see how HLFB is applied inside the `SDPA` python implementation:
103
+
To address the performance issues of LLM, we identified that the model's forward pass is usually bottlenecked by a few key operations, such as KV cache or scaled dot product attention. As the converter traces the whole model and performs graph lowering through Aten/CoreAten/StableHLO etc, there are chances that certain op groups are not properly fused together, resulting in bad runtime performance. For example, we leverage PyTorch's [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) to implement attention computation, however it will be converted to less efficient form when converting directly (which leads to unncessary FLOPs). The core problem here is that the converter doesn't know the boundary of a `SDPA` op so it can't apply any optimizations on it. We use ODML torch's high-level function boundary API to mark the boundary of the `SDPA` operation, and then lower it to TF Lite custom op. For example, see how HLFB is applied inside the `SDPA` python implementation:
As a result, everything in between the `builder.mark_inputs` and `builder.mark_outputs` call will be wrapped inside a StableHLO composite op named `odml.scaled_dot_product_attention`, and map to TF Lite's [SDPA custom op](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/genai/sdpa.cc) by the converter. During runtime, the delegate will match and replace the `odml.scaled_dot_product_attention` to the most efficient kernels on that delegate backend.
0 commit comments