@@ -277,6 +277,105 @@ def get_op_nodes_not_followed_by_specific_op(model, op1, op2):
277277
278278 return not_selected_op1_nodes
279279
def custom_write_calibration_table(calibration_cache, dir="."):
    """
    Helper function to write calibration table to files.

    Three files are written into ``dir``:

    * ``calibration.json``        - JSON dump of ``calibration_cache``.
    * ``calibration.flatbuffers`` - FlatBuffers ``TrtTable`` of key/value strings.
    * ``calibration.cache``       - plain text, one ``"<key> <value>"`` per line.

    For the FlatBuffers and plain-text outputs, the value stored for each key
    is ``str(max(highest, lowest))`` taken from the entry's ``to_dict()``
    result; a missing ``"highest"`` or ``"lowest"`` is treated as 0.

    Args:
        calibration_cache: Mapping from tensor name to an object exposing
            ``to_dict()`` (e.g. ``TensorData`` or a thin wrapper around a
            plain dict). The ``"highest"`` / ``"lowest"`` entries may be
            numpy arrays, numpy scalars or plain Python numbers.
        dir: Output directory. The name shadows the builtin but is kept
            unchanged so existing keyword callers continue to work.

    Raises:
        AttributeError: If a cache entry has no ``to_dict()`` method.
        TypeError: If the cache contains an object the JSON encoder cannot
            serialize.
    """

    import json
    import logging
    import flatbuffers
    import numpy as np

    import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
    import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
    from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData

    logging.info(f"calibration cache: {calibration_cache}")

    class MyEncoder(json.JSONEncoder):
        """JSON encoder for calibration objects, numpy arrays and dict wrappers."""

        def default(self, obj):
            if isinstance(obj, (TensorData, TensorsData)):
                return obj.to_dict()
            if isinstance(obj, np.ndarray):
                return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
            if isinstance(obj, CalibrationMethod):
                return {"CLS": obj.__class__.__name__, "value": str(obj)}
            # Wrapper objects are recognised by duck typing. The wrapper class
            # used by callers is defined locally in the caller's scope, so
            # naming it here would raise NameError during json.dumps.
            data_dict = getattr(obj, "data_dict", None)
            if data_dict is not None:
                return data_dict
            to_dict = getattr(obj, "to_dict", None)
            if callable(to_dict):
                return to_dict()
            return json.JSONEncoder.default(self, obj)

    def _as_float(number):
        # numpy arrays/scalars expose .item(); plain Python numbers do not.
        return float(number.item() if hasattr(number, "item") else number)

    json_data = json.dumps(calibration_cache, cls=MyEncoder)

    with open(os.path.join(dir, "calibration.json"), "w") as file:
        file.write(json_data)  # use `json.loads` to do the reverse

    # Compute the (key, value-string) pairs once; both the FlatBuffers table
    # and the plain-text cache are built from exactly the same data.
    zero = np.array(0)
    entries = []
    for key in sorted(calibration_cache):
        d_values = calibration_cache[key].to_dict()
        floats = [
            _as_float(d_values.get("highest", zero)),
            _as_float(d_values.get("lowest", zero)),
        ]
        entries.append((key, str(max(floats))))

    # Serialize data using FlatBuffers
    builder = flatbuffers.Builder(1024)
    key_value_list = []

    for key, value in entries:
        # Strings must be created before the table that references them is started.
        flat_key = builder.CreateString(key)
        flat_value = builder.CreateString(value)

        KeyValue.KeyValueStart(builder)
        KeyValue.KeyValueAddKey(builder, flat_key)
        KeyValue.KeyValueAddValue(builder, flat_value)
        key_value_list.append(KeyValue.KeyValueEnd(builder))

    TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
    for key_value in key_value_list:
        builder.PrependUOffsetTRelative(key_value)
    main_dict = builder.EndVector()

    TrtTable.TrtTableStart(builder)
    TrtTable.TrtTableAddDict(builder, main_dict)
    cal_table = TrtTable.TrtTableEnd(builder)

    builder.Finish(cal_table)
    buf = builder.Output()

    with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
        file.write(buf)

    # Deserialize data (for validation)
    if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"):
        cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
        dict_len = cal_table.DictLength()
        for i in range(dict_len):
            key_value = cal_table.Dict(i)
            logging.info(key_value.Key())
            logging.info(key_value.Value())

    # write plain text
    with open(os.path.join(dir, "calibration.cache"), "w") as file:
        for key, value in entries:
            file.write(key + " " + value)
            file.write("\n")
280379
281380def parse_input_args ():
282381 parser = argparse .ArgumentParser ()
@@ -553,8 +652,42 @@ def output_run_config(flags, samples):
553652 for k , v in compute_range .data .items ():
554653 json_compute_range [k ] = (float (v .range_value [0 ]), float (v .range_value [1 ]))
555654
655+ print ("Writing calibration table" )
656+ try :
657+ write_calibration_table (json_compute_range )
658+ except AttributeError as e :
659+ class TensorDataWrapper :
660+ def __init__ (self , data_dict ):
661+ self .data_dict = data_dict
662+
663+ def to_dict (self ):
664+ return self .data_dict
665+
666+ def __repr__ (self ):
667+ return repr (self .data_dict )
668+
669+ def __serializable__ (self ):
670+ return self .data_dict
671+
672+ calibration_data = {}
673+ for k , v in compute_range .data .items ():
674+ if hasattr (v , 'to_dict' ):
675+ tensor_dict = v .to_dict ()
676+ processed_dict = {}
677+ for dk , dv in tensor_dict .items ():
678+ if isinstance (dv , np .ndarray ):
679+ processed_dict [dk ] = dv .item () if dv .size == 1 else dv .tolist ()
680+ elif isinstance (dv , np .number ):
681+ processed_dict [dk ] = dv .item ()
682+ else :
683+ processed_dict [dk ] = dv
684+ calibration_data [k ] = TensorDataWrapper (processed_dict )
685+ else :
686+ calibration_data [k ] = v
687+
688+ print ("Using custom calibration table function" )
689+ custom_write_calibration_table (calibration_data )
556690
557- write_calibration_table (json_compute_range )
558691 print ("Calibration is done. Calibration cache is saved to calibration.json" )
559692
560693 model_quants = model_quants + "_int8"
0 commit comments