@@ -43,7 +43,9 @@ def cli_export(command, model_dir):
 
 
 def check_causal_lm_output_quality(
-    model_id: str, generated_tokens: List[int], max_perplexity_threshold: float = 100.0
+    model_id: str,
+    generated_tokens: List[int],
+    max_perplexity_threshold: float = 100.0,
 ):
     """
     Evaluates the quality of text generated by a causal language model by calculating its perplexity.
@@ -58,12 +60,26 @@ def check_causal_lm_output_quality(
     """
     logging.info(f"Starting perplexity check with model '{model_id}' ...")
     # Load model
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        low_cpu_mem_usage=True,
-        use_cache=False,
-        torch_dtype=torch.bfloat16,
-    )
+    cls_name = AutoModelForCausalLM
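+    # Llava checkpoints bundle a vision tower and load via LlavaForConditionalGeneration, not AutoModelForCausalLM.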
+    if "llava" in model_id:
+        from transformers import LlavaForConditionalGeneration
+
+        cls_name = LlavaForConditionalGeneration
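+    # Some model classes do not accept use_cache in from_pretrained; retry without it on TypeError.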
+    try:
+        model = cls_name.from_pretrained(
+            model_id,
+            low_cpu_mem_usage=True,
+            use_cache=False,
+            torch_dtype=torch.bfloat16,
+        )
+    except TypeError:
+        model = cls_name.from_pretrained(
+            model_id,
+            low_cpu_mem_usage=True,
+            torch_dtype=torch.bfloat16,
+        )
 
     with torch.no_grad():
         outputs = model(input_ids=generated_tokens, labels=generated_tokens)
@@ -156,6 +172,89 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
     assert check_causal_lm_output_quality(model_id, generated_tokens) is True
 
 
+def test_llm_with_image_modality(
+    model_id, model_dir, recipe, *, quantize=True, run_only=False
+):
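+    # Export via optimum-cli to ExecuTorch with custom SDPA and KV-cache ops, 8da4w-quantized linears, and 8-bit embeddings.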
+    command = [
+        "optimum-cli",
+        "export",
+        "executorch",
+        "--model",
+        model_id,
+        "--task",
+        "multimodal-text-to-text",
+        "--recipe",
+        recipe,
+        "--output_dir",
+        model_dir,
+        "--use_custom_sdpa",
+        "--use_custom_kv_cache",
+        "--qlinear",
+        "8da4w",
+        "--qembedding",
+        "8w",
+    ]
+    if not run_only:
+        cli_export(command, model_dir)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.save_pretrained(model_dir)
+
+    # Build the multimodal prompt: a system message plus a user turn holding an image and a question.
+    processor = AutoProcessor.from_pretrained(model_id)
+    image_url = "https://llava-vl.github.io/static/images/view.jpg"
+    conversation = [
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
+                }
+            ],
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "url": image_url},
+                {
+                    "type": "text",
+                    "text": "What are the things I should be cautious about when I visit here?",
+                },
+            ],
+        },
+    ]
+    inputs = processor.apply_chat_template(
+        conversation,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
232+ 
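+    # Run the exported .pte with ExecuTorch's multimodal runner (greedy decode, up to 128 new tokens).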
+    from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner
+
+    runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
+    generated_text = runner.generate_text_hf(
+        inputs,
+        GenerationConfig(max_new_tokens=128, temperature=0, echo=False),
+        processor.image_token_id,
+    )
+    print(f"\nGenerated text:\n\t{generated_text}")
+    # Free memory before loading eager for quality check
+    del runner
+    gc.collect()
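+    # Quality gate: perplexity of the generated tokens under the eager model must stay below the threshold.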
+    assert (
+        check_causal_lm_output_quality(
+            model_id, tokenizer.encode(generated_text, return_tensors="pt")
+        )
+        is True
+    )
+
+
 def test_fill_mask(model_id, model_dir, recipe, *, quantize=True, run_only=False):
     command = [
         "optimum-cli",
@@ -353,6 +452,9 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
         required=False,
         help="When provided, write the pte file to this directory. Otherwise, a temporary directory is created for the test.",
     )
+    parser.add_argument(
+        "--run_only", action="store_true", help="Skip export and only run the test"
+    )
     args = parser.parse_args()
 
     _text_generation_mapping = {
@@ -384,8 +486,16 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
         "vit": ("google/vit-base-patch16-224", test_vit),
     }
 
+    _multimodal_model_mapping = {
+        "gemma3-4b": ("google/gemma-3-4b-it", test_llm_with_image_modality),
+        "llava": ("llava-hf/llava-1.5-7b-hf", test_llm_with_image_modality),
+    }
+
     model_to_model_id_and_test_function = (
-        _text_generation_mapping | _mask_fill_mapping | _misc_model_mapping
+        _text_generation_mapping
+        | _mask_fill_mapping
+        | _misc_model_mapping
+        | _multimodal_model_mapping
     )
 
     if args.model not in model_to_model_id_and_test_function:
@@ -400,4 +510,5 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
             model_dir=tmp_dir if args.model_dir is None else args.model_dir,
             recipe=args.recipe,
             quantize=args.quantize,
+            run_only=args.run_only,
         )