Commit 6935caf

larryliu0820 authored and facebook-github-bot committed

Add BUCK files for llava python and C++ libs (#8297)

Summary: Add BUCK targets so they can be used in fbcode, Bento, etc.

Reviewed By: luyich

Differential Revision: D69278781

1 parent 77f18b2 · commit 6935caf

File tree

examples/models/llava/export_llava.py
examples/models/llava/image_util.py
examples/models/llava/model.py
examples/models/llava/targets.bzl

4 files changed: +49 -2 lines

examples/models/llava/export_llava.py

Lines changed: 1 addition & 0 deletions

@@ -67,6 +67,7 @@ def export(self) -> "LlavaEdgeManager":
             dynamic_shapes=dynamic_shape,
             strict=False,
         )
+        # pyre-ignore: Incompatible attribute type [8]: Attribute `pre_autograd_graph_module` declared in class `LLMEdgeManager` has type `Optional[GraphModule]` but is used as type `Module`.
         self.pre_autograd_graph_module = self.export_program.module()
         return self
 

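A note on the mechanics, since the same pattern recurs in every file below: a `# pyre-ignore` comment suppresses one Pyre diagnostic on the line immediately after it, optionally scoped to an error code in brackets. A minimal, self-contained sketch of the attribute-type mismatch suppressed above (the class and attribute names are hypothetical, not from this commit):

from typing import Optional

import torch


class Manager:
    # Declared with a narrower type than what the setter produces.
    graph_module: Optional[torch.fx.GraphModule] = None

    def export(self) -> "Manager":
        program = torch.export.export(torch.nn.Linear(2, 2), (torch.randn(1, 2),))
        # `ExportedProgram.module()` is annotated as returning `torch.nn.Module`,
        # which is wider than `Optional[GraphModule]`, so Pyre reports
        # error [8] on the next line unless it is suppressed.
        # pyre-ignore[8]
        self.graph_module = program.module()
        return self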
examples/models/llava/image_util.py

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 
 
+# pyre-ignore: Undefined or invalid type [11]: Annotation `Image` is not defined as a type.
 def prepare_image(image: Image, target_h: int, target_w: int) -> torch.Tensor:
     """Read image into a tensor and resize the image so that it fits in
     a target_h x target_w canvas.

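Here the suppressed error is `Image` being the PIL module rather than the `Image.Image` class, which Pyre cannot treat as a type. A hedged usage sketch for `prepare_image`, assuming a local RGB image file; the 336x336 target matches the CLIP crop size LLaVA-1.5 conventionally uses (an assumption here; model.py reads it from `image_processor.crop_size` rather than hard-coding it):

import torch
from PIL import Image

from executorch.examples.models.llava.image_util import prepare_image

img = Image.open("example.jpg")  # hypothetical path; any RGB image works
tensor = prepare_image(img, target_h=336, target_w=336)
assert isinstance(tensor, torch.Tensor)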
examples/models/llava/model.py

Lines changed: 23 additions & 2 deletions

@@ -21,6 +21,8 @@
 from executorch.examples.models.llama.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,
 )
+
+# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.examples.models.llava.image_util`.
 from executorch.examples.models.llava.image_util import prepare_image
 from executorch.examples.models.model_base import EagerModelBase
 from PIL import Image

@@ -48,6 +50,7 @@ def __init__(
         self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.model_ = llava_model
         self.image_processor = image_processor
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `config`.
         self.vision_feature_layer = self.model_.config.vision_feature_layer
         self.vision_feature_select_strategy = (
             self.model_.config.vision_feature_select_strategy

@@ -76,6 +79,7 @@ def __init__(
         )
 
     def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
         state_dict = self.model_.language_model.state_dict()
         key_map = {
             # fmt: off

@@ -128,9 +132,11 @@ def get_model(self):
         return self.model_.get_model()
 
     def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
         return self.model_.language_model.model.embed_tokens(tokens)
 
     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `dtype`.
         images = images.to(dtype=self.model_.dtype)
         if type(images) is list:
             image_features = []

@@ -144,15 +150,19 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
                 image_feature = self._feature_select(image_forward_out).to(image.dtype)
                 image_features.append(image_feature)
         else:
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `vision_tower`.
             image_forward_outs = self.model_.vision_tower(
+                # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `device`.
                 images.to(device=self.model_.device, dtype=self.model_.dtype),
                 output_hidden_states=True,
             )
             image_features = self._feature_select(image_forward_outs).to(images.dtype)
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `multi_modal_projector`.
         image_features = self.model_.multi_modal_projector(image_features)
         return image_features
 
     def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `crop_size`.
         target_h = self.image_processor.crop_size["height"]
         target_w = self.image_processor.crop_size["width"]
         # pad the image with median rgb value, to make a square

@@ -195,10 +205,15 @@ def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
         # print(resized.shape)
         # cropped = F.center_crop(img, output_size=[w, w])
         # print(cropped.shape)
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `rescale_factor`.
         scaled = resized * self.image_processor.rescale_factor
         # print(scaled)
         normed = F.normalize(
-            scaled, self.image_processor.image_mean, self.image_processor.image_std
+            scaled,
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `image_mean`.
+            self.image_processor.image_mean,
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `image_std`.
+            self.image_processor.image_std,
         )
         # print(normed)
         return normed.unsqueeze(0)

@@ -249,7 +264,9 @@ def prefill_ref(
     ) -> torch.Tensor:
         """Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
         embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
+        # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `LlamaForCausalLM`.
         return LlamaForCausalLM.forward(
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
             self.model_.language_model,
             inputs_embeds=embeds,
             return_dict=False,

@@ -268,12 +285,16 @@ class LlavaModel(EagerModelBase):
     def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
         self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.max_seq_len = max_seq_len
-        self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+        self.processor = AutoProcessor.from_pretrained(
+            "llava-hf/llava-1.5-7b-hf",
+            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
+        )
         self.tokenizer = self.processor.tokenizer
         self.image_processor = self.processor.image_processor
         self.model = LlavaForConditionalGeneration.from_pretrained(
             "llava-hf/llava-1.5-7b-hf",
             device_map="cpu",
+            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
         )
         self.image = Image.open(
             requests.get(

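Beyond the suppressions, the substantive change in this file is pinning the Hugging Face repo revision: `from_pretrained(..., revision=...)` downloads a fixed commit of `llava-hf/llava-1.5-7b-hf` instead of the moving `main` branch, which keeps the example reproducible under transformers >= 4.44.2. The same pattern in isolation, as a minimal sketch (it needs network access on first run):

from transformers import AutoProcessor

# Pin the exact model-repo commit; the hash is the one this commit uses.
processor = AutoProcessor.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",
)
print(type(processor).__name__)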
examples/models/llava/targets.bzl

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_oss_build_kwargs", "runtime")
+
+def define_common_targets():
+    runtime.cxx_binary(
+        name = "main",
+        srcs = [
+            "main.cpp",
+        ],
+        compiler_flags = ["-Wno-global-constructors"],
+        preprocessor_flags = [
+            "-DET_USE_THREADPOOL",
+        ],
+        deps = [
+            "//executorch/examples/models/llava/runner:runner",
+            "//executorch/extension/evalue_util:print_evalue",
+            "//executorch/extension/threadpool:cpuinfo_utils",
+            "//executorch/extension/threadpool:threadpool",
+        ],
+        external_deps = [
+            "gflags",
+            "torch-core-cpp",
+        ],
+        **get_oss_build_kwargs()
+    )

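With the target in place, the example binary should be buildable with Buck, along the lines of `buck2 build //executorch/examples/models/llava:main`. The exact cell prefix and path are an assumption inferred from the dep labels above; the commit itself does not show a build invocation.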