File tree Expand file tree Collapse file tree 9 files changed +104
-14
lines changed Expand file tree Collapse file tree 9 files changed +104
-14
lines changed Original file line number Diff line number Diff line change @@ -109,6 +109,7 @@ runtime.python_library(
109109 "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
110110 "//caffe2:torch",
111111 "//executorch/backends/vulkan/_passes:vulkan_passes",
112+ "//executorch/exir/passes:init_mutable_pass",
112113 "//executorch/examples/models:model_base",
113114 "//executorch/examples/models:models",
114115 "//executorch/exir/passes:init_mutable_pass",
Original file line number Diff line number Diff line change 1+ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+ oncall("executorch")
4+
5+ python_library(
6+ name = "multimodal_lib",
7+ srcs = [
8+ "__init__.py",
9+ ],
10+ deps = [
11+ "//executorch/examples/models/llama3_2_vision/text_decoder:model",
12+ "//executorch/examples/models/llama3_2_vision/vision_encoder:model",
13+ ],
14+ )
Original file line number Diff line number Diff line change 1+ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+ oncall("executorch")
4+
5+ python_library(
6+ name = "model",
7+ srcs = [
8+ "model.py",
9+ ],
10+ deps = [
11+ "//caffe2:torch",
12+ "//executorch/examples/models:checkpoint",
13+ "//pytorch/torchtune:lib",
14+ "//executorch/extension/llm/modules:module_lib",
15+ ],
16+ )
17+
Original file line number Diff line number Diff line change @@ -133,20 +133,6 @@ def __init__(self, **kwargs):
133133 print (unexpected )
134134 print ("============= /unexpected ================" )
135135
136- # Prune the output layer if output_prune_map is provided.
137- output_prune_map = None
138- if self .output_prune_map_path is not None :
139- from executorch .examples .models .llama2 .source_transformation .prune_output import (
140- prune_output_vocab ,
141- )
142-
143- with open (self .output_prune_map_path , "r" ) as f :
144- output_prune_map = json .load (f )
145- # Change keys from string to int (json only supports string keys)
146- output_prune_map = {int (k ): v for (k , v ) in output_prune_map .items ()}
147-
148- self .model_ = prune_output_vocab (self .model_ , output_prune_map )
149-
150136 if self .use_kv_cache :
151137 print ("Setting up KV cache on the model..." )
152138 self .model_ .setup_caches (
Original file line number Diff line number Diff line change 1+ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+ oncall("executorch")
4+
5+ python_library(
6+ name = "model",
7+ srcs = [
8+ "__init__.py",
9+ "model.py",
10+ ],
11+ deps = [
12+ "//caffe2:torch",
13+ "//executorch/extension/llm/modules:module_lib",
14+ "//pytorch/torchtune:lib",
15+ "//executorch/examples/models:model_base",
16+ ],
17+ )
Original file line number Diff line number Diff line change 44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ # pyre-ignore-all-errors
8+
79from dataclasses import dataclass , field
810from typing import Optional
911
Original file line number Diff line number Diff line change 1+ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+ oncall("executorch")
4+
5+ python_library(
6+ name = "kv_cache",
7+ srcs = [
8+ "kv_cache.py",
9+ ],
10+ deps = [
11+ "//caffe2:torch",
12+ "//pytorch/torchtune:lib",
13+ ],
14+ )
15+
16+ python_library(
17+ name = "attention",
18+ srcs = [
19+ "attention.py",
20+ ],
21+ deps = [
22+ ":kv_cache",
23+ "//caffe2:torch",
24+ "//executorch/extension/llm/custom_ops:custom_ops",
25+ "//pytorch/torchtune:lib",
26+ ],
27+ )
28+
29+ python_library(
30+ name = "position_embeddings",
31+ srcs = [
32+ "_position_embeddings.py",
33+ ],
34+ deps = [
35+ "//caffe2:torch",
36+ ],
37+ )
38+
39+ python_library(
40+ name = "module_lib",
41+ srcs = [
42+ "__init__.py",
43+ ],
44+ deps= [
45+ ":position_embeddings",
46+ ":attention",
47+ ":kv_cache",
48+ ]
49+ )
Original file line number Diff line number Diff line change 88# Added torch._check() to make sure guards on symints are enforced.
99# See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py
1010
11+ # pyre-ignore-all-errors
12+
1113import logging
1214import math
1315from typing import Any , Dict , Tuple
Original file line number Diff line number Diff line change 44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ # pyre-ignore-all-errors
8+
79import logging
810from typing import Optional
911
You can’t perform that action at this time.
0 commit comments