From 9422acadcd0e3cef14c75ddddf802656df146068 Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Mon, 3 Feb 2025 15:16:01 -0800
Subject: [PATCH] Buckify Llama multimodal export (#7604)

Summary: Buckify export path for Llama multimodal

Reviewed By: digantdesai, iseeyuan

Differential Revision: D67949092

Pulled By: jackzhxng
---
 examples/models/llama/TARGETS                 |  1 +
 examples/models/llama3_2_vision/TARGETS       | 14 ++++++
 .../llama3_2_vision/text_decoder/TARGETS      | 16 ++++++
 .../llama3_2_vision/text_decoder/model.py     | 14 ------
 .../llama3_2_vision/vision_encoder/TARGETS    | 17 +++++++
 .../llama3_2_vision/vision_encoder/model.py   |  2 +
 extension/llm/modules/TARGETS                 | 49 +++++++++++++++++++
 extension/llm/modules/_position_embeddings.py |  2 +
 extension/llm/modules/attention.py            |  2 +
 9 files changed, 103 insertions(+), 14 deletions(-)
 create mode 100644 examples/models/llama3_2_vision/TARGETS
 create mode 100644 examples/models/llama3_2_vision/text_decoder/TARGETS
 create mode 100644 examples/models/llama3_2_vision/vision_encoder/TARGETS
 create mode 100644 extension/llm/modules/TARGETS

diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index f6b78e876c8..0c9fcd31ad1 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -109,5 +109,6 @@ runtime.python_library(
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
         "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
+        "//executorch/exir/passes:init_mutable_pass",
diff --git a/examples/models/llama3_2_vision/TARGETS b/examples/models/llama3_2_vision/TARGETS
new file mode 100644
index 00000000000..133fd3a6839
--- /dev/null
+++ b/examples/models/llama3_2_vision/TARGETS
@@ -0,0 +1,14 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("executorch")
+
+python_library(
+    name = "multimodal_lib",
+    srcs = [
+        "__init__.py",
+    ],
+    deps = [
+        "//executorch/examples/models/llama3_2_vision/text_decoder:model",
+        "//executorch/examples/models/llama3_2_vision/vision_encoder:model",
+    ],
+)
diff --git a/examples/models/llama3_2_vision/text_decoder/TARGETS b/examples/models/llama3_2_vision/text_decoder/TARGETS
new file mode 100644
index 00000000000..e87b567f2bc
--- /dev/null
+++ b/examples/models/llama3_2_vision/text_decoder/TARGETS
@@ -0,0 +1,16 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("executorch")
+
+python_library(
+    name = "model",
+    srcs = [
+        "model.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models:checkpoint",
+        "//pytorch/torchtune:lib",
+        "//executorch/extension/llm/modules:module_lib",
+    ],
+)
diff --git a/examples/models/llama3_2_vision/text_decoder/model.py b/examples/models/llama3_2_vision/text_decoder/model.py
index 8cdbd8628a3..8f3a620affc 100644
--- a/examples/models/llama3_2_vision/text_decoder/model.py
+++ b/examples/models/llama3_2_vision/text_decoder/model.py
@@ -133,20 +133,6 @@ def __init__(self, **kwargs):
         print(unexpected)
         print("============= /unexpected ================")
 
-        # Prune the output layer if output_prune_map is provided.
-        output_prune_map = None
-        if self.output_prune_map_path is not None:
-            from executorch.examples.models.llama2.source_transformation.prune_output import (
-                prune_output_vocab,
-            )
-
-            with open(self.output_prune_map_path, "r") as f:
-                output_prune_map = json.load(f)
-            # Change keys from string to int (json only supports string keys)
-            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
-
-            self.model_ = prune_output_vocab(self.model_, output_prune_map)
-
         if self.use_kv_cache:
             print("Setting up KV cache on the model...")
             self.model_.setup_caches(
diff --git a/examples/models/llama3_2_vision/vision_encoder/TARGETS b/examples/models/llama3_2_vision/vision_encoder/TARGETS
new file mode 100644
index 00000000000..82717a56d08
--- /dev/null
+++ b/examples/models/llama3_2_vision/vision_encoder/TARGETS
@@ -0,0 +1,17 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("executorch")
+
+python_library(
+    name = "model",
+    srcs = [
+        "__init__.py",
+        "model.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/extension/llm/modules:module_lib",
+        "//pytorch/torchtune:lib",
+        "//executorch/examples/models:model_base",
+    ],
+)
diff --git a/examples/models/llama3_2_vision/vision_encoder/model.py b/examples/models/llama3_2_vision/vision_encoder/model.py
index 79becd16205..7730d08ea08 100644
--- a/examples/models/llama3_2_vision/vision_encoder/model.py
+++ b/examples/models/llama3_2_vision/vision_encoder/model.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-ignore-all-errors
+
 from dataclasses import dataclass, field
 from typing import Optional
 
diff --git a/extension/llm/modules/TARGETS b/extension/llm/modules/TARGETS
new file mode 100644
index 00000000000..c0d2edf3818
--- /dev/null
+++ b/extension/llm/modules/TARGETS
@@ -0,0 +1,49 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("executorch")
+
+python_library(
+    name = "kv_cache",
+    srcs = [
+        "kv_cache.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//pytorch/torchtune:lib",
+    ],
+)
+
+python_library(
+    name = "attention",
+    srcs = [
+        "attention.py",
+    ],
+    deps = [
+        ":kv_cache",
+        "//caffe2:torch",
+        "//executorch/extension/llm/custom_ops:custom_ops",
+        "//pytorch/torchtune:lib",
+    ],
+)
+
+python_library(
+    name = "position_embeddings",
+    srcs = [
+        "_position_embeddings.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+python_library(
+    name = "module_lib",
+    srcs = [
+        "__init__.py",
+    ],
+    deps = [
+        ":position_embeddings",
+        ":attention",
+        ":kv_cache",
+    ],
+)
diff --git a/extension/llm/modules/_position_embeddings.py b/extension/llm/modules/_position_embeddings.py
index 3fd68a2184c..874019ec3dd 100644
--- a/extension/llm/modules/_position_embeddings.py
+++ b/extension/llm/modules/_position_embeddings.py
@@ -8,6 +8,8 @@
 # Added torch._check() to make sure guards on symints are enforced.
 # See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py
 
+# pyre-ignore-all-errors
+
 import logging
 import math
 from typing import Any, Dict, Tuple
diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index a138585e427..bb688c2b8c1 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-ignore-all-errors
+
 import logging
 from typing import Optional
 
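
For reference, a minimal sketch of how a downstream TARGETS file could consume
the targets added by this patch. The python_binary below is hypothetical: its
name, source file, and main_module are illustrative assumptions, not part of
this change; only the two deps are targets this patch actually creates.

    load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")

    python_binary(
        # Hypothetical target name and entry point, shown only to illustrate
        # depending on the newly buckified libraries.
        name = "export_multimodal",
        srcs = ["export_multimodal.py"],
        main_module = "executorch.examples.models.llama3_2_vision.export_multimodal",
        deps = [
            # Targets created by this patch:
            "//executorch/examples/models/llama3_2_vision:multimodal_lib",
            "//executorch/extension/llm/modules:module_lib",
        ],
    )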