@@ -43,7 +43,9 @@ def cli_export(command, model_dir):
 
 
 def check_causal_lm_output_quality(
-    model_id: str, generated_tokens: List[int], max_perplexity_threshold: float = 100.0
+    model_id: str,
+    generated_tokens: List[int],
+    max_perplexity_threshold: float = 100.0,
 ):
     """
     Evaluates the quality of text generated by a causal language model by calculating its perplexity.
@@ -58,12 +60,24 @@ def check_causal_lm_output_quality(
     """
     logging.info(f"Starting perplexity check with model '{model_id}' ...")
     # Load model
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        low_cpu_mem_usage=True,
-        use_cache=False,
-        torch_dtype=torch.bfloat16,
-    )
+    cls_name = AutoModelForCausalLM
+    if "llava" in model_id:
+        from transformers import LlavaForConditionalGeneration
+
+        cls_name = LlavaForConditionalGeneration
+    try:
+        model = cls_name.from_pretrained(
+            model_id,
+            low_cpu_mem_usage=True,
+            use_cache=False,
+            torch_dtype=torch.bfloat16,
+        )
+    except TypeError:
+        model = cls_name.from_pretrained(
+            model_id,
+            low_cpu_mem_usage=True,
+            torch_dtype=torch.bfloat16,
+        )
 
     with torch.no_grad():
         outputs = model(input_ids=generated_tokens, labels=generated_tokens)
@@ -156,6 +170,105 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
     assert check_causal_lm_output_quality(model_id, generated_tokens) is True
 
 
+def test_llm_with_image_modality(
+    model_id, model_dir, recipe, *, quantize=True, run_only=False
+):
+    command = [
+        "optimum-cli",
+        "export",
+        "executorch",
+        "--model",
+        model_id,
+        "--task",
+        "multimodal-text-to-text",
+        "--recipe",
+        recipe,
+        "--output_dir",
+        model_dir,
+        "--use_custom_sdpa",
+        "--use_custom_kv_cache",
+        "--qlinear",
+        "8da4w",
+        "--qembedding",
+        "8w",
+    ]
+    if not run_only:
+        cli_export(command, model_dir)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.save_pretrained(model_dir)
+
+    # input
+    processor = AutoProcessor.from_pretrained(model_id)
+    image_url = "https://llava-vl.github.io/static/images/view.jpg"
+    conversation = [
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
+                }
+            ],
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "url": image_url},
+                {
+                    "type": "text",
+                    "text": "What are the things I should be cautious about when I visit here?",
+                },
+            ],
+        },
+    ]
+    inputs = processor.apply_chat_template(
+        conversation,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
+
+    import torch
+
+    first_image_id_index = torch.where(inputs["input_ids"] == processor.image_token_id)[
+        1
+    ][0].item()
+    last_image_id_index = torch.where(inputs["input_ids"] == processor.image_token_id)[
+        1
+    ][-1].item()
+
+    prompt_before_image = inputs["input_ids"][0, :first_image_id_index]
+    prompt_after_image = inputs["input_ids"][0, last_image_id_index + 1 :]
+    from executorch.extension.llm.runner import (
+        GenerationConfig,
+        make_image_input,
+        make_token_input,
+        MultimodalRunner,
+    )
+
+    combined_inputs = [
+        make_token_input(prompt_before_image.tolist()),
+        make_image_input(inputs["pixel_values"]),
+        make_token_input(prompt_after_image.tolist()),
+    ]
+    runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
+    generated_text = runner.generate_text(
+        combined_inputs, GenerationConfig(max_new_tokens=128, temperature=0, echo=False)
+    )
+    print(f"\nGenerated text:\n\t{generated_text}")
+    # Free memory before loading eager for quality check
+    del runner
+    gc.collect()
+    assert (
+        check_causal_lm_output_quality(
+            model_id, tokenizer.encode(generated_text, return_tensors="pt")
+        )
+        is True
+    )
+
+
 def test_fill_mask(model_id, model_dir, recipe, *, quantize=True, run_only=False):
     command = [
         "optimum-cli",
@@ -353,6 +466,9 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
         required=False,
         help="When provided, write the pte file to this directory. Otherwise, a temporary directory is created for the test.",
     )
+    parser.add_argument(
+        "--run_only", action="store_true", help="Skip export and only run the test"
+    )
     args = parser.parse_args()
 
     _text_generation_mapping = {
@@ -384,8 +500,16 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
         "vit": ("google/vit-base-patch16-224", test_vit),
     }
 
+    _multimodal_model_mapping = {
+        "gemma3-4b": ("google/gemma-3-4b-it", test_llm_with_image_modality),
+        "llava": ("llava-hf/llava-1.5-7b-hf", test_llm_with_image_modality),
+    }
+
     model_to_model_id_and_test_function = (
-        _text_generation_mapping | _mask_fill_mapping | _misc_model_mapping
+        _text_generation_mapping
+        | _mask_fill_mapping
+        | _misc_model_mapping
+        | _multimodal_model_mapping
     )
 
     if args.model not in model_to_model_id_and_test_function:
@@ -400,4 +524,5 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
         model_dir=tmp_dir if args.model_dir is None else args.model_dir,
         recipe=args.recipe,
         quantize=args.quantize,
+        run_only=args.run_only,
     )
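
For reference, a minimal usage sketch of the new multimodal test once this diff is applied. The recipe name, the output directory, and the direct-call style are assumptions for illustration, not part of the change; the model id matches the "llava" entry in _multimodal_model_mapping.

# Sketch only: calls the test added in this diff directly from Python.
# The recipe ("xnnpack") and the scratch directory are assumed values, not from the diff.
from pathlib import Path

out_dir = Path("/tmp/llava_multimodal_export")  # hypothetical output directory
out_dir.mkdir(parents=True, exist_ok=True)

test_llm_with_image_modality(
    "llava-hf/llava-1.5-7b-hf",  # same model id as the "llava" mapping entry
    model_dir=str(out_dir),
    recipe="xnnpack",  # assumed ExecuTorch recipe
    run_only=False,  # True skips the optimum-cli export and reuses an existing model.pte
)

The script's CLI path reaches the same function through the new _multimodal_model_mapping entry and the --run_only flag added above.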