
Commit 2c7b7e8

[llava] Enable dynamic shape for image preprocessor
Differential Revision: D61818152
Pull Request resolved: #4821

1 parent b284866 · commit 2c7b7e8

File tree

5 files changed: +142 −32 lines


.ci/scripts/test_llava.sh

Lines changed: 14 additions & 1 deletion
@@ -54,6 +54,13 @@ export_llava() {
   $PYTHON_EXECUTABLE -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts
 }
 
+# Download a new image with different size, to test if the model can handle different image sizes
+prepare_image_tensor() {
+  echo "Downloading image"
+  curl -o basketball.jpg https://upload.wikimedia.org/wikipedia/commons/7/73/Chicago_Bulls_and_New_Jersey_Nets%2C_March_28%2C_1991.jpg
+  $PYTHON_EXECUTABLE -m executorch.examples.models.llava.image_util --image-path basketball.jpg --output-path image.pt
+}
+
 run_and_verify() {
   NOW=$(date +"%H:%M:%S")
   echo "Starting to run llava runner at ${NOW}"
@@ -79,7 +86,12 @@ run_and_verify() {
   # verify result.txt
   RESULT=$(cat result.txt)
   # set the expected prefix to be the same as prompt because there's a bug in sdpa_with_kv_cache that causes <unk> tokens.
-  EXPECTED_PREFIX="ASSISTANT:"
+  if [[ "$(uname)" == "Darwin" ]]; then
+    EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress on a basketball court. There are several players on the court, with one player in the foreground holding a basketball, and"
+  else
+    # set the expected prefix to be the same as prompt because there's a bug in sdpa_with_kv_cache that causes <unk> tokens.
+    EXPECTED_PREFIX="ASSISTANT:"
+  fi
   if [[ "${RESULT}" == *"${EXPECTED_PREFIX}"* ]]; then
     echo "Expected result prefix: ${EXPECTED_PREFIX}"
     echo "Actual result: ${RESULT}"
@@ -96,4 +108,5 @@ run_and_verify() {
 cmake_install_executorch_libraries
 cmake_build_llava_runner
 export_llava
+prepare_image_tensor
 run_and_verify

examples/models/llava/export_llava.py

Lines changed: 2 additions & 9 deletions
@@ -22,6 +22,7 @@
 from executorch.examples.models.llama2.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,
 )
+from executorch.examples.models.llava.image_util import serialize_image
 from executorch.examples.models.llava.model import LlavaModel
 from executorch.exir import (
     EdgeCompileConfig,
@@ -35,7 +36,6 @@
 
 from executorch.extension.llm.export.builder import DType, LLMEdgeManager
 from executorch.extension.llm.tokenizer.tokenizer import Tokenizer
-from torch import nn
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
@@ -231,14 +231,7 @@ def get_image_tensor_for_llava_runner(llava_model):
     # llava runner doesn't have image reader so an image tensor is needed.
     (resized,) = llava_model.get_example_inputs()
 
-    copy = torch.tensor(resized)
-    m = nn.Module()
-    par = nn.Parameter(copy, requires_grad=False)
-    m.register_parameter("0", par)
-    tensors = torch.jit.script(m)
-    tensors.save("image.pt")
-
-    logging.info("Saved image tensor to image.pt")
+    serialize_image(resized, "image.pt")
 
 
 def get_tokenizer_for_llava_runner(llava_model):
examples/models/llava/image_util.py

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Utility functions for image processing. Run it with your image:
+
+# python image_util.py --image-path <path_to_image>
+
+import logging
+from argparse import ArgumentParser
+
+import torch
+import torchvision
+from PIL import Image
+from torch import nn
+
+
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+def prepare_image(image: Image, target_h: int, target_w: int) -> torch.Tensor:
+    """Read image into a tensor and resize the image so that it fits in
+    a target_h x target_w canvas.
+
+    Args:
+        image (Image): An Image object.
+        target_h (int): Target height.
+        target_w (int): Target width.
+
+    Returns:
+        torch.Tensor: resized image tensor.
+    """
+    img = torchvision.transforms.functional.pil_to_tensor(image)
+    # height ratio
+    ratio_h = img.shape[1] / target_h
+    # width ratio
+    ratio_w = img.shape[2] / target_w
+    # resize the image so that it fits in a target_h x target_w canvas
+    ratio = max(ratio_h, ratio_w)
+    output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio))
+    img = torchvision.transforms.Resize(size=output_size)(img)
+    return img
+
+
+def serialize_image(image: torch.Tensor, path: str) -> None:
+    copy = torch.tensor(image)
+    m = nn.Module()
+    par = nn.Parameter(copy, requires_grad=False)
+    m.register_parameter("0", par)
+    tensors = torch.jit.script(m)
+    tensors.save(path)
+
+    logging.info(f"Saved image tensor to {path}")
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--image-path",
+        required=True,
+        help="Path to the image.",
+    )
+    parser.add_argument(
+        "--output-path",
+        default="image.pt",
+    )
+    args = parser.parse_args()
+
+    image = Image.open(args.image_path)
+    image_tensor = prepare_image(image, target_h=336, target_w=336)
+    serialize_image(image_tensor, args.output_path)
+
+
+if __name__ == "__main__":
+    main()
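Note: serialize_image stores the tensor as a parameter named "0" on a scripted nn.Module, so the resulting image.pt can be read back with torch.jit.load. A minimal sketch of the reading side (an assumption for illustration; the llava runner does the equivalent in C++):

import torch

# Load the scripted module written by serialize_image() and pull the tensor
# back out of its single parameter, which is registered under the name "0".
loaded = torch.jit.load("image.pt")
image_tensor = dict(loaded.named_parameters())["0"]
print(image_tensor.shape)  # e.g. torch.Size([3, 240, 336]) for a landscape image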

examples/models/llava/model.py

Lines changed: 36 additions & 21 deletions
@@ -6,20 +6,18 @@
 
 # An ExecuTorch friendly implementation of Llava-1.5.
 
-import math
-
 import re
 
 from typing import Any, Dict, Optional
 
 import requests
 import torch
-import torchvision
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
 
 from executorch.examples.models.llama2.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,
 )
+from executorch.examples.models.llava.image_util import prepare_image
 from executorch.examples.models.model_base import EagerModelBase
 from PIL import Image
 
@@ -156,19 +154,32 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
         return image_features
 
     def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
-        w = max(img.shape[1], img.shape[2])
+        target_h = self.image_processor.crop_size["height"]
+        target_w = self.image_processor.crop_size["width"]
         # pad the image with median rgb value, to make a square
-        v_padding = (w - img.shape[1]) / 2
-        h_padding = (w - img.shape[2]) / 2
-        l_pad = int(math.ceil(h_padding))
-        t_pad = int(math.ceil(v_padding))
-        r_pad = int(math.floor(h_padding))
-        b_pad = int(math.floor(v_padding))
-        resized = F.pad(
+        l_pad = (target_w - img.shape[2]) // 2
+        t_pad = (target_h - img.shape[1]) // 2
+        # ceil division
+        r_pad = -((target_w - img.shape[2]) // -2)
+        b_pad = -((target_h - img.shape[1]) // -2)
+
+        torch._check(l_pad >= 0)
+        torch._check(t_pad >= 0)
+        torch._check(r_pad >= 0)
+        torch._check(b_pad >= 0)
+
+        # This is different from the original implementation, due to export limitations.
+        resized = torch.nn.functional.pad(
             img,
-            padding=(l_pad, t_pad, r_pad, b_pad),
-            fill=tuple(int(x * 255) for x in self.image_processor.image_mean),
+            (l_pad, r_pad, t_pad, b_pad),
         )
+        # originally:
+        # resized = F.pad(
+        #     img,
+        #     padding=(l_pad, t_pad, r_pad, b_pad),
+        #     fill=tuple(int(x * 255) for x in self.image_mean),
+        # )
+
         # TODO: implement _upsample_bicubic_aa.out in portable kernel library.
         # here padded shape should be max(h, w) x max(h, w)
         # skipping resize for now due to missing _upsample_bicubic_aa kernel in portable
@@ -287,13 +298,12 @@ def get_example_inputs(self):
         """Returns a resized image as input to model.forward()."""
         if self.resized_image:
             return self.resized_image
-        imagr = torchvision.transforms.functional.pil_to_tensor(self.image)
-        ratio = (
-            max(imagr.shape[1], imagr.shape[2])
-            / self.image_processor.crop_size["height"]
+        resized = prepare_image(
+            self.image,
+            self.image_processor.crop_size["height"],
+            self.image_processor.crop_size["width"],
         )
-        output_size = (int(imagr.shape[1] / ratio), int(imagr.shape[2] / ratio))
-        self.resized_image = (torchvision.transforms.Resize(size=output_size)(imagr),)
+        self.resized_image = (resized,)
         return self.resized_image
 
     def get_inputs_for_prefill(self):
@@ -317,8 +327,13 @@ def get_dynamic_shapes(self):
         return self._get_image_dynamic_shapes()
 
     def _get_image_dynamic_shapes(self):
-        height = Dim("height", min=8, max=336)
-        width = Dim("width", min=28, max=336)
+        # only support even number of height and width for now
+        _height = Dim(
+            "_height", min=1, max=self.image_processor.crop_size["height"] // 2
+        )
+        _width = Dim("_width", min=1, max=self.image_processor.crop_size["width"] // 2)
+        height = 2 * _height
+        width = 2 * _width
         dynamic_shapes = [{1: height, 2: width}]
         return dynamic_shapes
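Note on _get_image_dynamic_shapes: writing height = 2 * _height uses a derived Dim, which tells torch.export that the dimension is always even; together with the torch._check guards, that lets the data-dependent padding arithmetic export cleanly. A standalone sketch of the same export pattern, assuming a fixed 336x336 target (the toy PadToCanvas module and its names are illustrative, not the repo's API):

import torch
from torch.export import Dim, export


class PadToCanvas(torch.nn.Module):
    def forward(self, img):
        target_h, target_w = 336, 336
        # the floor half of the slack goes on the left/top...
        l_pad = (target_w - img.shape[2]) // 2
        t_pad = (target_h - img.shape[1]) // 2
        # ...and the ceil half on the right/bottom: -(a // -b) == ceil(a / b)
        r_pad = -((target_w - img.shape[2]) // -2)
        b_pad = -((target_h - img.shape[1]) // -2)
        # assert to the exporter that the pads are non-negative
        torch._check(l_pad >= 0)
        torch._check(t_pad >= 0)
        torch._check(r_pad >= 0)
        torch._check(b_pad >= 0)
        return torch.nn.functional.pad(img, (l_pad, r_pad, t_pad, b_pad))


# Each spatial size is expressed as 2 * Dim(...), i.e. constrained to be even.
_h = Dim("_h", min=1, max=336 // 2)
_w = Dim("_w", min=1, max=336 // 2)
ep = export(
    PadToCanvas(),
    (torch.zeros(3, 240, 336),),
    dynamic_shapes=({1: 2 * _h, 2: 2 * _w},),
)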

examples/models/llava/test/test_pte.py

Lines changed: 12 additions & 1 deletion
@@ -8,9 +8,10 @@
 import sys
 
 import torch
-
+from executorch.examples.models.llava.image_util import prepare_image
 from executorch.examples.models.llava.model import LlavaModel
 from executorch.extension.pybindings.portable_lib import _load_for_executorch
+from PIL import Image
 
 # Custom ops has to be loaded after portable_lib.
 # I don't know how to stop UFMT so I'm just using if True: to avoid lint error
@@ -24,13 +25,23 @@
 
 def main():
     args = sys.argv[1:]
+    if len(args) == 0:
+        print(
+            "Usage: python test_pte.py <model_path> <image_path?>. If no image, will use default image."
+        )
+        sys.exit(1)
+
     llava_module = _load_for_executorch(args[0])
 
     llava_model = LlavaModel()
 
     prompt_before_image, resized, prompt_after_image = (
         llava_model.get_inputs_for_prefill()
     )
+    if len(args) == 2:
+        image_path = args[1]
+        image = Image.open(image_path)
+        resized = prepare_image(image, target_h=336, target_w=336)
 
     start_pos = 0
     # pte prefill prompt before img
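With this change the PTE smoke test accepts an optional image path. A hypothetical invocation, with file names for illustration only:

python examples/models/llava/test/test_pte.py llava.pte basketball.jpg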
