1
1
import argparse
2
+ import gc
3
+ import logging
4
+ import math
2
5
import subprocess
3
6
import tempfile
4
7
from pathlib import Path
8
+ from typing import List
5
9
6
10
import torch
7
11
from datasets import load_dataset
15
19
)
16
20
from transformers import (
17
21
AutoConfig ,
22
+ AutoModelForCausalLM ,
18
23
AutoModelForImageClassification ,
19
24
AutoProcessor ,
20
25
AutoTokenizer ,
@@ -37,6 +42,56 @@ def cli_export(command, model_dir):
37
42
print (f"Export failed with error: { e } " )
38
43
39
44
45
+ def check_causal_lm_output_quality (
46
+ model_id : str , generated_tokens : List [int ], max_perplexity_threshold : float = 100.0
47
+ ):
48
+ """
49
+ Evaluates the quality of text generated by a causal language model by calculating its perplexity.
50
+
51
+ Args:
52
+ model_id: HuggingFace model identifier (e.g., "google/gemma2-2b")
53
+ generated_tokens: The tokens generated by the exported model to evaluate
54
+ max_perplexity_threshold: Maximum acceptable perplexity (lower is better)
55
+
56
+ Returns:
57
+ tuple: (is_quality_ok, reason) with boolean result and explanation
58
+ """
59
+ logging .info (f"Starting perplexity check with model '{ model_id } ' ..." )
60
+ # Load model
61
+ model = AutoModelForCausalLM .from_pretrained (
62
+ model_id ,
63
+ low_cpu_mem_usage = True ,
64
+ use_cache = False ,
65
+ torch_dtype = torch .bfloat16 ,
66
+ )
67
+
68
+ with torch .no_grad ():
69
+ outputs = model (input_ids = generated_tokens , labels = generated_tokens )
70
+
71
+ # Get the loss (negative log-likelihood)
72
+ loss = outputs .loss .item ()
73
+
74
+ # Calculate perplexity (exp of the average negative log-likelihood)
75
+ perplexity = math .exp (loss )
76
+
77
+ is_quality_ok = perplexity <= max_perplexity_threshold
78
+ if is_quality_ok :
79
+ logging .info (
80
+ f"✓ Perplexity check passed: { perplexity :.2f} <= { max_perplexity_threshold } "
81
+ )
82
+ else :
83
+ logging .warning (
84
+ f"✗ Perplexity check failed: { perplexity :.2f} > { max_perplexity_threshold } "
85
+ )
86
+
87
+ # Clean up immediately
88
+ del model
89
+ del outputs
90
+ gc .collect ()
91
+
92
+ return is_quality_ok
93
+
94
+
40
95
def test_text_generation (model_id , model_dir , recipe , * , quantize = True , run_only = False ):
41
96
command = [
42
97
"optimum-cli" ,
@@ -51,7 +106,19 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
51
106
"--output_dir" ,
52
107
model_dir ,
53
108
]
54
- if "coreml" in recipe :
109
+ if "xnnpack" in recipe :
110
+ command += [
111
+ "--use_custom_sdpa" ,
112
+ "--use_custom_kv_cache" ,
113
+ ]
114
+ if quantize :
115
+ command += [
116
+ "--qlinear" ,
117
+ "8da4w" ,
118
+ "--qembedding" ,
119
+ "8w" ,
120
+ ]
121
+ elif "coreml" in recipe :
55
122
command += [
56
123
"--disable_dynamic_shapes" ,
57
124
]
@@ -63,7 +130,9 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
63
130
"8w" ,
64
131
]
65
132
else :
66
- assert not quantize , "Quantization is not supported for non-CoreML recipes yet"
133
+ assert (
134
+ not quantize
135
+ ), "Quantization is only supported for XnnPack and CoreML recipes at the moment."
67
136
68
137
if not run_only :
69
138
cli_export (command , model_dir )
@@ -77,6 +146,14 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
77
146
max_seq_len = 64 ,
78
147
)
79
148
print (f"\n Generated text:\n \t { generated_text } " )
149
+ generated_tokens = tokenizer (generated_text , return_tensors = "pt" ).input_ids
150
+
151
+ # Free memory before loading eager for quality check
152
+ del model
153
+ del tokenizer
154
+ gc .collect ()
155
+
156
+ assert check_causal_lm_output_quality (model_id , generated_tokens ) is True
80
157
81
158
82
159
def test_fill_mask (model_id , model_dir , recipe , * , quantize = True , run_only = False ):
@@ -278,23 +355,39 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
278
355
)
279
356
args = parser .parse_args ()
280
357
281
- model_to_model_id_and_test_function = {
282
- "smollm" : ("HuggingFaceTB/SmolLM2-135M" , test_text_generation ), # works
283
- "qwen3" : ("Qwen/Qwen3-0.6B" , test_text_generation ), # works
284
- "olmo" : ("allenai/OLMo-1B-hf" , test_text_generation ), # works
285
- "gemma3" : ("unsloth/gemma-3-1b-it" , test_text_generation ), # does not export
286
- "phi4" : (
358
+ _text_generation_mapping = {
359
+ "llama3.2-1b" : ("NousResearch/Llama-3.2-1B" , test_text_generation ),
360
+ "qwen3-0.6b" : ("Qwen/Qwen3-0.6B" , test_text_generation ),
361
+ "qwen3-1.7b" : ("Qwen/Qwen3-1.7B" , test_text_generation ),
362
+ "gemma3-1b" : (
363
+ "unsloth/gemma-3-1b-it" ,
364
+ test_text_generation ,
365
+ ), # does not export for CoreML
366
+ "phi4-mini" : (
287
367
"microsoft/Phi-4-mini-instruct" ,
288
368
test_text_generation ,
289
- ), # fails to lower
290
- "llama3" : ("NousResearch/Llama-3.2-1B" , test_text_generation ), # works
291
- "bert" : ("google-bert/bert-base-uncased" , test_fill_mask ), # works
292
- "roberta" : ("FacebookAI/xlmcl-roberta-base" , test_fill_mask ), # works
293
- "distilbert" : ("distilbert/distilbert-base-uncased" , test_fill_mask ), # works
294
- "whisper" : ("openai/whisper-tiny" , test_whisper ), # works
369
+ ), # fails to lower for CoreML
370
+ "smollm2-135m" : ("HuggingFaceTB/SmolLM2-135M" , test_text_generation ),
371
+ "smollm3-3b" : ("HuggingFaceTB/SmolLM3-3B" , test_text_generation ),
372
+ "olmo" : ("allenai/OLMo-1B-hf" , test_text_generation ),
373
+ }
374
+
375
+ _mask_fill_mapping = {
376
+ "bert" : ("google-bert/bert-base-uncased" , test_fill_mask ),
377
+ "roberta" : ("FacebookAI/xlmcl-roberta-base" , test_fill_mask ),
378
+ "distilbert" : ("distilbert/distilbert-base-uncased" , test_fill_mask ),
379
+ }
380
+
381
+ _misc_model_mapping = {
382
+ "whisper" : ("openai/whisper-tiny" , test_whisper ),
295
383
"t5" : ("google-t5/t5-small" , test_t5 ), # CoreML runime failure
296
- "vit" : ("google/vit-base-patch16-224" , test_vit ), # works
384
+ "vit" : ("google/vit-base-patch16-224" , test_vit ),
297
385
}
386
+
387
+ model_to_model_id_and_test_function = (
388
+ _text_generation_mapping | _mask_fill_mapping | _misc_model_mapping
389
+ )
390
+
298
391
if args .model not in model_to_model_id_and_test_function :
299
392
raise ValueError (
300
393
f"Unknown model name: { args .model } . Available models: { model_to_model_id_and_test_function .keys ()} "
0 commit comments