11import argparse
2+ import gc
3+ import logging
4+ import math
25import subprocess
36import tempfile
47from pathlib import Path
8+ from typing import List
59
610import torch
711from datasets import load_dataset
1519)
1620from transformers import (
1721 AutoConfig ,
22+ AutoModelForCausalLM ,
1823 AutoModelForImageClassification ,
1924 AutoProcessor ,
2025 AutoTokenizer ,
@@ -37,6 +42,56 @@ def cli_export(command, model_dir):
3742 print (f"Export failed with error: { e } " )
3843
3944
45+ def check_causal_lm_output_quality (
46+ model_id : str , generated_tokens : List [int ], max_perplexity_threshold : float = 100.0
47+ ):
48+ """
49+ Evaluates the quality of text generated by a causal language model by calculating its perplexity.
50+
51+ Args:
52+ model_id: HuggingFace model identifier (e.g., "google/gemma2-2b")
53+ generated_tokens: The tokens generated by the exported model to evaluate
54+ max_perplexity_threshold: Maximum acceptable perplexity (lower is better)
55+
56+ Returns:
57+ tuple: (is_quality_ok, reason) with boolean result and explanation
58+ """
59+ logging .info (f"Starting perplexity check with model '{ model_id } ' ..." )
60+ # Load model
61+ model = AutoModelForCausalLM .from_pretrained (
62+ model_id ,
63+ low_cpu_mem_usage = True ,
64+ use_cache = False ,
65+ torch_dtype = torch .bfloat16 ,
66+ )
67+
68+ with torch .no_grad ():
69+ outputs = model (input_ids = generated_tokens , labels = generated_tokens )
70+
71+ # Get the loss (negative log-likelihood)
72+ loss = outputs .loss .item ()
73+
74+ # Calculate perplexity (exp of the average negative log-likelihood)
75+ perplexity = math .exp (loss )
76+
77+ is_quality_ok = perplexity <= max_perplexity_threshold
78+ if is_quality_ok :
79+ logging .info (
80+ f"✓ Perplexity check passed: { perplexity :.2f} <= { max_perplexity_threshold } "
81+ )
82+ else :
83+ logging .warning (
84+ f"✗ Perplexity check failed: { perplexity :.2f} > { max_perplexity_threshold } "
85+ )
86+
87+ # Clean up immediately
88+ del model
89+ del outputs
90+ gc .collect ()
91+
92+ return is_quality_ok
93+
94+
4095def test_text_generation (model_id , model_dir , recipe , * , quantize = True , run_only = False ):
4196 command = [
4297 "optimum-cli" ,
@@ -51,7 +106,19 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
51106 "--output_dir" ,
52107 model_dir ,
53108 ]
54- if "coreml" in recipe :
109+ if "xnnpack" in recipe :
110+ command += [
111+ "--use_custom_sdpa" ,
112+ "--use_custom_kv_cache" ,
113+ ]
114+ if quantize :
115+ command += [
116+ "--qlinear" ,
117+ "8da4w" ,
118+ "--qembedding" ,
119+ "8w" ,
120+ ]
121+ elif "coreml" in recipe :
55122 command += [
56123 "--disable_dynamic_shapes" ,
57124 ]
@@ -63,7 +130,9 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
63130 "8w" ,
64131 ]
65132 else :
66- assert not quantize , "Quantization is not supported for non-CoreML recipes yet"
133+ assert (
134+ not quantize
135+ ), "Quantization is only supported for XnnPack and CoreML recipes at the moment."
67136
68137 if not run_only :
69138 cli_export (command , model_dir )
@@ -77,6 +146,14 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
77146 max_seq_len = 64 ,
78147 )
79148 print (f"\n Generated text:\n \t { generated_text } " )
149+ generated_tokens = tokenizer (generated_text , return_tensors = "pt" ).input_ids
150+
151+ # Free memory before loading eager for quality check
152+ del model
153+ del tokenizer
154+ gc .collect ()
155+
156+ assert check_causal_lm_output_quality (model_id , generated_tokens ) is True
80157
81158
82159def test_fill_mask (model_id , model_dir , recipe , * , quantize = True , run_only = False ):
@@ -278,23 +355,39 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
278355 )
279356 args = parser .parse_args ()
280357
281- model_to_model_id_and_test_function = {
282- "smollm" : ("HuggingFaceTB/SmolLM2-135M" , test_text_generation ), # works
283- "qwen3" : ("Qwen/Qwen3-0.6B" , test_text_generation ), # works
284- "olmo" : ("allenai/OLMo-1B-hf" , test_text_generation ), # works
285- "gemma3" : ("unsloth/gemma-3-1b-it" , test_text_generation ), # does not export
286- "phi4" : (
358+ _text_generation_mapping = {
359+ "llama3.2-1b" : ("NousResearch/Llama-3.2-1B" , test_text_generation ),
360+ "qwen3-0.6b" : ("Qwen/Qwen3-0.6B" , test_text_generation ),
361+ "qwen3-1.7b" : ("Qwen/Qwen3-1.7B" , test_text_generation ),
362+ "gemma3-1b" : (
363+ "unsloth/gemma-3-1b-it" ,
364+ test_text_generation ,
365+ ), # does not export for CoreML
366+ "phi4-mini" : (
287367 "microsoft/Phi-4-mini-instruct" ,
288368 test_text_generation ,
289- ), # fails to lower
290- "llama3" : ("NousResearch/Llama-3.2-1B" , test_text_generation ), # works
291- "bert" : ("google-bert/bert-base-uncased" , test_fill_mask ), # works
292- "roberta" : ("FacebookAI/xlmcl-roberta-base" , test_fill_mask ), # works
293- "distilbert" : ("distilbert/distilbert-base-uncased" , test_fill_mask ), # works
294- "whisper" : ("openai/whisper-tiny" , test_whisper ), # works
369+ ), # fails to lower for CoreML
370+ "smollm2-135m" : ("HuggingFaceTB/SmolLM2-135M" , test_text_generation ),
371+ "smollm3-3b" : ("HuggingFaceTB/SmolLM3-3B" , test_text_generation ),
372+ "olmo-1b" : ("allenai/OLMo-1B-hf" , test_text_generation ),
373+ }
374+
375+ _mask_fill_mapping = {
376+ "bert" : ("google-bert/bert-base-uncased" , test_fill_mask ),
377+ "roberta" : ("FacebookAI/xlmcl-roberta-base" , test_fill_mask ),
378+ "distilbert" : ("distilbert/distilbert-base-uncased" , test_fill_mask ),
379+ }
380+
381+ _misc_model_mapping = {
382+ "whisper" : ("openai/whisper-tiny" , test_whisper ),
295383 "t5" : ("google-t5/t5-small" , test_t5 ), # CoreML runime failure
296- "vit" : ("google/vit-base-patch16-224" , test_vit ), # works
384+ "vit" : ("google/vit-base-patch16-224" , test_vit ),
297385 }
386+
387+ model_to_model_id_and_test_function = (
388+ _text_generation_mapping | _mask_fill_mapping | _misc_model_mapping
389+ )
390+
298391 if args .model not in model_to_model_id_and_test_function :
299392 raise ValueError (
300393 f"Unknown model name: { args .model } . Available models: { model_to_model_id_and_test_function .keys ()} "
0 commit comments