@@ -306,20 +306,19 @@ def is_float_tensor(self, tensor_name):
306
306
)
307
307
return False
308
308
309
- def _get_dynamic_input_quantization_params (self , input_name , nodes_list , qType ):
309
+ def _get_dynamic_input_quantization_params (self , input_name , nodes_list , qType , initial_type ):
310
310
"""
311
311
Create nodes for dynamic quantization of input and add them to nodes_list.
312
312
parameter input_name: Name of the input.
313
313
parameter nodes_list: new nodes are appended to this list.
314
314
parameter qType: type to quantize to.
315
+ parameter initial_type: type to quantize from; forwarded unchanged to the INT8/UINT8 helper.
315
316
return: scale_name, zero_point_name, scale_shape, zero_point_shape.
raise: ValueError if qType is neither INT8 nor UINT8.
316
317
"""
317
318
if qType == onnx_proto .TensorProto .INT8 :
318
- return self ._get_dynamic_input_quantization_params_int8 (input_name , nodes_list )
319
+ return self ._get_dynamic_input_quantization_params_int8 (input_name , nodes_list , initial_type )
319
320
if qType == onnx_proto .TensorProto .UINT8 :
320
- return self ._get_dynamic_input_quantization_params_uint8 (input_name , nodes_list )
321
- if qType == onnx_proto .TensorProto .FLOAT8E4M3FN :
322
- return self ._get_dynamic_input_quantization_params_float8e4m3fn (input_name , nodes_list )
321
+ return self ._get_dynamic_input_quantization_params_uint8 (input_name , nodes_list , initial_type )
323
322
raise ValueError (f"Unexpected value for qType={ qType } ." )
324
323
325
324
def _get_dynamic_input_quantization_params_int8 (self , input_name , nodes_list , initial_type ):
@@ -559,7 +558,9 @@ def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=Non
559
558
560
559
return True , scale_name , zero_point_name , scale_shape , zero_point_shape
561
560
562
- def _get_quantize_input_nodes (self , node , input_index , qType , given_scale_name = None , given_zp_name = None ):
561
+ def _get_quantize_input_nodes (
562
+ self , node , input_index , qType , given_scale_name = None , given_zp_name = None , initial_type = None
563
+ ):
563
564
"""
564
565
Given an input for a node (which is not an initializer), this function
565
566
@@ -571,6 +572,7 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N
571
572
:param qType: type to quantize to.
572
573
:param given_scale_name: if those inputs need to be quantized using this scale tensor.
573
574
:param given_zp_name: if those inputs need to be quantized using this zeropoint tensor.
575
+ :param initial_type: element type of the input tensor before quantization
574
576
:return: List of newly created nodes in NodeProto format.
575
577
"""
576
578
input_name = node .input [input_index ]
@@ -606,12 +608,16 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N
606
608
ql_node_name ,
607
609
)
608
610
else :
611
+ assert initial_type is not None , (
612
+ f"Cannot quantize input without knowing the initial type, "
613
+ f"input_name={ input_name !r} , input_index={ input_index } , qType={ qType } , node={ node } "
614
+ )
609
615
(
610
616
scale_name ,
611
617
zp_name ,
612
618
scale_shape ,
613
619
zp_shape ,
614
- ) = self ._get_dynamic_input_quantization_params (input_name , nodes , qType )
620
+ ) = self ._get_dynamic_input_quantization_params (input_name , nodes , qType , initial_type = initial_type )
615
621
qlinear_node = onnx .helper .make_node (
616
622
"QuantizeLinear" ,
617
623
[input_name , scale_name , zp_name ],
@@ -794,7 +800,23 @@ def __quantize_inputs(
794
800
node_input + "_QuantizeLinear" , self .new_nodes , self .model .graph ()
795
801
)
796
802
if qlinear_node is None :
797
- quantize_input_nodes = self ._get_quantize_input_nodes (node , input_index , self .activation_qType )
803
+ input_name = node .input [input_index ]
804
+ if input_name in self .value_infos :
805
+ value_info = self .value_infos [input_name ]
806
+ assert value_info .HasField ("type" ), f"value_info={ value_info } has no type."
807
+ assert value_info .type .HasField ("tensor_type" ), f"value_info={ value_info } is not a tensor."
808
+ initial_type = value_info .type .tensor_type .elem_type
809
+ else :
810
+ # Shape inference failed. Fallback to self.tensor_names.
811
+ assert input_name in self .tensor_names , (
812
+ f"shape inference failed for { input_name !r} and "
813
+ f"attribute 'tensor_names' does not have any value for "
814
+ f"this tensor."
815
+ )
816
+ initial_type = self .tensor_names [input_name ]
817
+ quantize_input_nodes = self ._get_quantize_input_nodes (
818
+ node , input_index , self .activation_qType , initial_type = initial_type
819
+ )
798
820
if quantize_input_nodes is None :
799
821
return (None , None , None , None )
800
822
if from_subgraph :
0 commit comments