
import argparse
import copy
-import json
import logging
import os

from executorch.backends.arm.tosa.partitioner import TOSAPartitioner

from executorch.backends.arm.util.arm_model_evaluator import (
-    GenericModelEvaluator,
-    MobileNetV2Evaluator,
+    evaluate_model,
+    evaluator_calibration_data,
)

from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
@@ -188,46 +187,6 @@ def quantize(
    return m


-# Simple example models
-class AddModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        return x + x
-
-    example_input = (torch.ones(5, dtype=torch.int32),)
-    can_delegate = True
-
-
-class AddModule2(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, y):
-        return x + y
-
-    example_input = (
-        torch.ones(5, dtype=torch.int32),
-        torch.ones(5, dtype=torch.int32),
-    )
-    can_delegate = True
-
-
-class AddModule3(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, y):
-        return (x + y, x + x)
-
-    example_input = (
-        torch.ones(5, dtype=torch.int32),
-        torch.ones(5, dtype=torch.int32),
-    )
-    can_delegate = True
-
-
class QuantAddTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -276,27 +235,6 @@ def forward(self, w, x, y, z):
    can_delegate = True  # when quantized


-class SoftmaxModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.softmax = torch.nn.Softmax(dim=0)
-
-    def forward(self, x):
-        z = self.softmax(x)
-        return z
-
-    example_input = (torch.ones(2, 2),)
-    can_delegate = True
-
-
-class MultipleOutputsModule(torch.nn.Module):
-    def forward(self, x: torch.Tensor, y: torch.Tensor):
-        return (x * y, x.sum(dim=-1, keepdim=True))
-
-    example_input = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
-    can_delegate = True
-
-
class QuantLinearTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -311,29 +249,15 @@ def forward(self, x):


models = {
-    "add": AddModule,
-    "add2": AddModule2,
-    "add3": AddModule3,
    "qadd": QuantAddTest,
    "qadd2": QuantAddTest2,
    "qops": QuantOpTest,
-    "softmax": SoftmaxModule,
-    "MultipleOutputsModule": MultipleOutputsModule,
    # TODO: Remove this from here, once we have dedicated MCU test pipeline ready. This is an interim solution.
    # See https://github.com/pytorch/executorch/discussions/13944
    "qlinear": QuantLinearTest,
}

calibration_data = {
-    "add": (torch.randn(1, 5),),
-    "add2": (
-        torch.randn(1, 5),
-        torch.randn(1, 5),
-    ),
-    "add3": (
-        torch.randn(32, 5),
-        torch.randn(32, 5),
-    ),
    "qadd": (torch.randn(32, 2, 1),),
    "qadd2": (
        torch.randn(32, 2, 1),
@@ -345,13 +269,6 @@ def forward(self, x):
        torch.randn(32, 2, 1) * -0.000001,
        torch.randn(32, 2, 1) * 1000,
    ),
-    "softmax": (torch.randn(32, 2, 2),),
-    "qlinear": (torch.randn(37, 61),),
-}
-
-evaluators = {
-    "generic": GenericModelEvaluator,
-    "mv2": MobileNetV2Evaluator,
}

targets = [
@@ -378,21 +295,7 @@ def get_calibration_data(
):
    # Firstly, if the model is being evaluated, take the evaluators calibration function if it has one
    if evaluator_name is not None:
-        evaluator = evaluators[evaluator_name]
-
-        if hasattr(evaluator, "get_calibrator"):
-            assert evaluator_config is not None
-
-            config_path = Path(evaluator_config)
-            with config_path.open() as f:
-                config = json.load(f)
-
-            if evaluator_name == "mv2":
-                return evaluator.get_calibrator(
-                    training_dataset_path=config["training_dataset_path"]
-                )
-            else:
-                raise RuntimeError(f"Unknown evaluator: {evaluator_name}")
+        return evaluator_calibration_data(evaluator_name, evaluator_config)

    # If the model is in the calibration_data dictionary, get the data from there
    # This is used for the simple model examples provided
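Note: the single added line above replaces the inlined config handling with a call into arm_model_evaluator. A minimal sketch of what such a helper might look like, assuming it simply wraps the logic deleted here; the actual implementation in executorch.backends.arm.util.arm_model_evaluator may differ.

# Hypothetical sketch only -- not the real arm_model_evaluator code.
import json
from pathlib import Path


def evaluator_calibration_data(evaluator_name: str, evaluator_config: str | None):
    # Assumes the evaluators mapping moved into this module alongside the helper.
    evaluator = evaluators[evaluator_name]

    if hasattr(evaluator, "get_calibrator"):
        assert evaluator_config is not None, "this evaluator needs a config file"
        config = json.loads(Path(evaluator_config).read_text())

        if evaluator_name == "mv2":
            return evaluator.get_calibrator(
                training_dataset_path=config["training_dataset_path"]
            )
        raise RuntimeError(f"Unknown evaluator: {evaluator_name}")
    return None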
@@ -446,52 +349,6 @@ def get_compile_spec(
    return compile_spec


-def evaluate_model(
-    model_name: str,
-    intermediates: str,
-    model_fp32: torch.nn.Module,
-    model_int8: torch.nn.Module,
-    example_inputs: Tuple[torch.Tensor],
-    evaluator_name: str,
-    evaluator_config: str | None,
-) -> None:
-    evaluator = evaluators[evaluator_name]
-
-    # Get the path of the TOSA flatbuffer that is dumped
-    intermediates_path = Path(intermediates)
-    tosa_paths = list(intermediates_path.glob("*.tosa"))
-
-    if evaluator.REQUIRES_CONFIG:
-        assert evaluator_config is not None
-
-        config_path = Path(evaluator_config)
-        with config_path.open() as f:
-            config = json.load(f)
-
-        if evaluator_name == "mv2":
-            init_evaluator = evaluator(
-                model_name,
-                model_fp32,
-                model_int8,
-                example_inputs,
-                str(tosa_paths[0]),
-                config["batch_size"],
-                config["validation_dataset_path"],
-            )
-        else:
-            raise RuntimeError(f"Unknown evaluator {evaluator_name}")
-    else:
-        init_evaluator = evaluator(
-            model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
-        )
-
-    quant_metrics = init_evaluator.evaluate()
-    output_json_path = intermediates_path / "quant_metrics.json"
-
-    with output_json_path.open("w") as json_file:
-        json.dump(quant_metrics, json_file)
-
-
def dump_delegation_info(edge, intermediate_files_folder: Optional[str] = None):
    graph_module = edge.exported_program().graph_module
    delegation_info = get_delegation_info(graph_module)
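For context, evaluate_model is no longer defined locally; it is imported from executorch.backends.arm.util.arm_model_evaluator (see the import hunk at the top). Below is a condensed sketch of what the relocated helper presumably does, assuming it kept the signature of the deleted local version; the mv2-specific guard from the deleted code is elided for brevity and the upstream implementation may differ.

# Hypothetical sketch only -- the real helper lives in arm_model_evaluator.
import json
from pathlib import Path
from typing import Optional, Tuple

import torch


def evaluate_model(
    model_name: str,
    intermediates: str,
    model_fp32: torch.nn.Module,
    model_int8: torch.nn.Module,
    example_inputs: Tuple[torch.Tensor],
    evaluator_name: str,
    evaluator_config: Optional[str],
) -> None:
    # Assumes the evaluators mapping moved into this module as well.
    evaluator = evaluators[evaluator_name]

    # Pick up the TOSA flatbuffer dumped into the intermediates folder.
    intermediates_path = Path(intermediates)
    tosa_path = str(next(intermediates_path.glob("*.tosa")))

    if evaluator.REQUIRES_CONFIG:
        assert evaluator_config is not None, "this evaluator needs a config file"
        config = json.loads(Path(evaluator_config).read_text())
        init_evaluator = evaluator(
            model_name,
            model_fp32,
            model_int8,
            example_inputs,
            tosa_path,
            config["batch_size"],
            config["validation_dataset_path"],
        )
    else:
        init_evaluator = evaluator(
            model_name, model_fp32, model_int8, example_inputs, tosa_path
        )

    # Compare FP32 vs INT8 behaviour and write the metrics next to the intermediates.
    with (intermediates_path / "quant_metrics.json").open("w") as json_file:
        json.dump(init_evaluator.evaluate(), json_file)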