 from executorch.examples.models.llama.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,
 )
+
+# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.examples.models.llava.image_util`.
 from executorch.examples.models.llava.image_util import prepare_image
 from executorch.examples.models.model_base import EagerModelBase
 from PIL import Image
@@ -48,6 +50,7 @@ def __init__(
         self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.model_ = llava_model
         self.image_processor = image_processor
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `config`.
         self.vision_feature_layer = self.model_.config.vision_feature_layer
         self.vision_feature_select_strategy = (
             self.model_.config.vision_feature_select_strategy
@@ -76,6 +79,7 @@ def __init__(
         )

     def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
         state_dict = self.model_.language_model.state_dict()
         key_map = {
             # fmt: off
@@ -128,9 +132,11 @@ def get_model(self):
         return self.model_.get_model()

     def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
         return self.model_.language_model.model.embed_tokens(tokens)

     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `dtype`.
         images = images.to(dtype=self.model_.dtype)
         if type(images) is list:
             image_features = []
@@ -144,15 +150,19 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
                 image_feature = self._feature_select(image_forward_out).to(image.dtype)
                 image_features.append(image_feature)
         else:
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `vision_tower`.
             image_forward_outs = self.model_.vision_tower(
+                # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `device`.
                 images.to(device=self.model_.device, dtype=self.model_.dtype),
                 output_hidden_states=True,
             )
             image_features = self._feature_select(image_forward_outs).to(images.dtype)
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `multi_modal_projector`.
         image_features = self.model_.multi_modal_projector(image_features)
         return image_features

     def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `crop_size`.
         target_h = self.image_processor.crop_size["height"]
         target_w = self.image_processor.crop_size["width"]
         # pad the image with median rgb value, to make a square
@@ -195,10 +205,15 @@ def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
         # print(resized.shape)
         # cropped = F.center_crop(img, output_size=[w, w])
         # print(cropped.shape)
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `rescale_factor`.
         scaled = resized * self.image_processor.rescale_factor
         # print(scaled)
         normed = F.normalize(
-            scaled, self.image_processor.image_mean, self.image_processor.image_std
+            scaled,
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `image_mean`.
+            self.image_processor.image_mean,
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `image_std`.
+            self.image_processor.image_std,
         )
         # print(normed)
         return normed.unsqueeze(0)
@@ -249,7 +264,9 @@ def prefill_ref(
     ) -> torch.Tensor:
         """Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
         embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
+        # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `LlamaForCausalLM`.
         return LlamaForCausalLM.forward(
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
             self.model_.language_model,
             inputs_embeds=embeds,
             return_dict=False,
@@ -268,12 +285,16 @@ class LlavaModel(EagerModelBase):
     def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
         self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.max_seq_len = max_seq_len
-        self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+        self.processor = AutoProcessor.from_pretrained(
+            "llava-hf/llava-1.5-7b-hf",
+            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
+        )
         self.tokenizer = self.processor.tokenizer
         self.image_processor = self.processor.image_processor
         self.model = LlavaForConditionalGeneration.from_pretrained(
             "llava-hf/llava-1.5-7b-hf",
             device_map="cpu",
+            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
         )
         self.image = Image.open(
             requests.get(