6
6
7
7
# An ExecuTorch-friendly implementation of Llava-1.5.
8
8
9
- import math
10
-
11
9
import re
12
10
13
11
from typing import Any , Dict , Optional
14
12
15
13
import requests
16
14
import torch
17
- import torchvision
18
15
from executorch .examples .models .llama2 .llama_transformer import ModelArgs , Transformer
19
16
20
17
from executorch .examples .models .llama2 .source_transformation .sdpa import (
21
18
replace_sdpa_with_custom_op ,
22
19
)
20
+ from executorch .examples .models .llava .image_util import prepare_image
23
21
from executorch .examples .models .model_base import EagerModelBase
24
22
from PIL import Image
25
23
@@ -156,19 +154,32 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
156
154
return image_features
157
155
158
156
def image_preprocess (self , img : torch .Tensor ) -> torch .Tensor :
159
- w = max (img .shape [1 ], img .shape [2 ])
157
+ target_h = self .image_processor .crop_size ["height" ]
158
+ target_w = self .image_processor .crop_size ["width" ]
160
159
# pad the image with the mean RGB value to make it square
161
- v_padding = (w - img .shape [1 ]) / 2
162
- h_padding = (w - img .shape [2 ]) / 2
163
- l_pad = int (math .ceil (h_padding ))
164
- t_pad = int (math .ceil (v_padding ))
165
- r_pad = int (math .floor (h_padding ))
166
- b_pad = int (math .floor (v_padding ))
167
- resized = F .pad (
160
+ l_pad = (target_w - img .shape [2 ]) // 2
161
+ t_pad = (target_h - img .shape [1 ]) // 2
162
+ # ceil division
163
+ r_pad = - ((target_w - img .shape [2 ]) // - 2 )
164
+ b_pad = - ((target_h - img .shape [1 ]) // - 2 )
165
+
166
+ torch ._check (l_pad >= 0 )
167
+ torch ._check (t_pad >= 0 )
168
+ torch ._check (r_pad >= 0 )
169
+ torch ._check (b_pad >= 0 )
170
+
171
+ # This is different from the original implementation, due to export limitations.
172
+ resized = torch .nn .functional .pad (
168
173
img ,
169
- padding = (l_pad , t_pad , r_pad , b_pad ),
170
- fill = tuple (int (x * 255 ) for x in self .image_processor .image_mean ),
174
+ (l_pad , r_pad , t_pad , b_pad ),
171
175
)
176
+ # originally:
177
+ # resized = F.pad(
178
+ # img,
179
+ # padding=(l_pad, t_pad, r_pad, b_pad),
180
+ # fill=tuple(int(x * 255) for x in self.image_mean),
181
+ # )
182
+
172
183
# TODO: implement _upsample_bicubic_aa.out in portable kernel library.
173
184
# here padded shape should be max(h, w) x max(h, w)
174
185
# skipping resize for now due to missing _upsample_bicubic_aa kernel in portable
@@ -287,13 +298,12 @@ def get_example_inputs(self):
287
298
"""Returns a resized image as input to model.forward()."""
288
299
if self .resized_image :
289
300
return self .resized_image
290
- imagr = torchvision . transforms . functional . pil_to_tensor ( self . image )
291
- ratio = (
292
- max ( imagr . shape [ 1 ], imagr . shape [ 2 ])
293
- / self .image_processor .crop_size ["height" ]
301
+ resized = prepare_image (
302
+ self . image ,
303
+ self . image_processor . crop_size [ "height" ],
304
+ self .image_processor .crop_size ["width" ],
294
305
)
295
- output_size = (int (imagr .shape [1 ] / ratio ), int (imagr .shape [2 ] / ratio ))
296
- self .resized_image = (torchvision .transforms .Resize (size = output_size )(imagr ),)
306
+ self .resized_image = (resized ,)
297
307
return self .resized_image
298
308
299
309
def get_inputs_for_prefill (self ):
@@ -317,8 +327,13 @@ def get_dynamic_shapes(self):
317
327
return self ._get_image_dynamic_shapes ()
318
328
319
329
def _get_image_dynamic_shapes (self ):
320
- height = Dim ("height" , min = 8 , max = 336 )
321
- width = Dim ("width" , min = 28 , max = 336 )
330
+ # only even heights and widths are supported for now
331
+ _height = Dim (
332
+ "_height" , min = 1 , max = self .image_processor .crop_size ["height" ] // 2
333
+ )
334
+ _width = Dim ("_width" , min = 1 , max = self .image_processor .crop_size ["width" ] // 2 )
335
+ height = 2 * _height
336
+ width = 2 * _width
322
337
dynamic_shapes = [{1 : height , 2 : width }]
323
338
return dynamic_shapes
324
339
0 commit comments