@@ -43,7 +43,9 @@ def cli_export(command, model_dir):
 
 
 def check_causal_lm_output_quality(
-    model_id: str, generated_tokens: List[int], max_perplexity_threshold: float = 100.0
+    model_id: str,
+    generated_tokens: List[int],
+    max_perplexity_threshold: float = 100.0,
 ):
     """
     Evaluates the quality of text generated by a causal language model by calculating its perplexity.
@@ -58,12 +60,26 @@ def check_causal_lm_output_quality(
     """
     logging.info(f"Starting perplexity check with model '{model_id}' ...")
     # Load model
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        low_cpu_mem_usage=True,
-        use_cache=False,
-        torch_dtype=torch.bfloat16,
-    )
+    cls_name = AutoModelForCausalLM
+    if "llava" in model_id:
+        from transformers import LlavaForConditionalGeneration
+
+        cls_name = LlavaForConditionalGeneration
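+    # Some model classes reject the `use_cache` kwarg in `from_pretrained`;
+    # retry without it when a TypeError is raised.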
+    try:
+        model = cls_name.from_pretrained(
+            model_id,
+            low_cpu_mem_usage=True,
+            use_cache=False,
+            torch_dtype=torch.bfloat16,
+        )
+    except TypeError:
+        model = cls_name.from_pretrained(
+            model_id,
+            low_cpu_mem_usage=True,
+            torch_dtype=torch.bfloat16,
+        )
 
     with torch.no_grad():
         outputs = model(input_ids=generated_tokens, labels=generated_tokens)
@@ -156,6 +172,91 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only=False):
     assert check_causal_lm_output_quality(model_id, generated_tokens) is True
 
 
+def test_llm_with_image_modality(
+    model_id, model_dir, recipe, *, quantize=True, run_only=False
+):
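+    # Export to ExecuTorch (.pte) for the multimodal-text-to-text task, using the
+    # custom SDPA and KV-cache kernels with 8da4w linear / 8w embedding quantization.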
+    command = [
+        "optimum-cli",
+        "export",
+        "executorch",
+        "--model",
+        model_id,
+        "--task",
+        "multimodal-text-to-text",
+        "--recipe",
+        recipe,
+        "--output_dir",
+        model_dir,
+        "--use_custom_sdpa",
+        "--use_custom_kv_cache",
+        "--qlinear",
+        "8da4w",
+        "--qembedding",
+        "8w",
+    ]
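+    # With --run_only, reuse a previously exported artifact instead of re-exporting.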
+    if not run_only:
+        cli_export(command, model_dir)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.save_pretrained(model_dir)
+
+    # Build the image + text chat input.
+    processor = AutoProcessor.from_pretrained(model_id)
+    image_url = "https://llava-vl.github.io/static/images/view.jpg"
+    conversation = [
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
+                }
+            ],
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "url": image_url},
+                {
+                    "type": "text",
+                    "text": "What are the things I should be cautious about when I visit here?",
+                },
+            ],
+        },
+    ]
+    inputs = processor.apply_chat_template(
+        conversation,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
+
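+    # Run the exported program with the ExecuTorch multimodal runner;
+    # temperature=0 makes decoding greedy and deterministic.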
+    from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner
+
+    runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
+    generated_text = runner.generate_text_hf(
+        inputs,
+        GenerationConfig(max_new_tokens=128, temperature=0, echo=False),
+        processor.image_token_id,
+    )
+    print(f"\nGenerated text:\n\t{generated_text}")
+    # Free memory before loading the eager model for the quality check
+    del runner
+    gc.collect()
+    assert (
+        check_causal_lm_output_quality(
+            model_id, tokenizer.encode(generated_text, return_tensors="pt")
+        )
+        is True
+    )
+
+
 def test_fill_mask(model_id, model_dir, recipe, *, quantize=True, run_only=False):
     command = [
         "optimum-cli",
@@ -353,6 +454,9 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
         required=False,
         help="When provided, write the pte file to this directory. Otherwise, a temporary directory is created for the test.",
     )
+    parser.add_argument(
+        "--run_only", action="store_true", help="Skip export and only run the test"
+    )
     args = parser.parse_args()
 
     _text_generation_mapping = {
@@ -384,8 +488,17 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
         "vit": ("google/vit-base-patch16-224", test_vit),
     }
 
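+    # --model keys for the multimodal (image + text) tests added above.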
+    _multimodal_model_mapping = {
+        "gemma3-4b": ("google/gemma-3-4b-it", test_llm_with_image_modality),
+        "llava": ("llava-hf/llava-1.5-7b-hf", test_llm_with_image_modality),
+    }
+
     model_to_model_id_and_test_function = (
-        _text_generation_mapping | _mask_fill_mapping | _misc_model_mapping
+        _text_generation_mapping
+        | _mask_fill_mapping
+        | _misc_model_mapping
+        | _multimodal_model_mapping
     )
 
     if args.model not in model_to_model_id_and_test_function:
@@ -400,4 +513,5 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
         model_dir=tmp_dir if args.model_dir is None else args.model_dir,
         recipe=args.recipe,
         quantize=args.quantize,
+        run_only=args.run_only,
     )