File tree Expand file tree Collapse file tree 3 files changed +11
-14
lines changed
torch/_deploy/_runtime/tensorrt
tests/_test_utils/onnx_quantization Expand file tree Collapse file tree 3 files changed +11
-14
lines changed Original file line number Diff line number Diff line change @@ -527,11 +527,13 @@ def _get_successive_consumers(
527
527
raise ValueError (f"Invalid consumer for { node .name } " )
528
528
529
529
quantized_node = tensor_consumers .get (dq_node .output [0 ], [None ])[0 ]
530
- if quantized_node .op_type == "Cast" :
531
- quantized_node = tensor_consumers .get (quantized_node .output [0 ], [None ])[0 ]
532
-
533
530
if not quantized_node :
534
531
raise ValueError (f"No consumer found for { dq_node .name } " )
532
+ if quantized_node .op_type == "Cast" :
533
+ next_node = tensor_consumers .get (quantized_node .output [0 ], [None ])[0 ]
534
+ if not next_node :
535
+ raise ValueError (f"No consumer found after Cast for { quantized_node .name } " )
536
+ quantized_node = next_node
535
537
536
538
return dq_node , quantized_node
537
539
Original file line number Diff line number Diff line change @@ -131,11 +131,6 @@ def get_output_shapes(
131
131
return output_shapes
132
132
133
133
134
- def validate_precision (precision : str ) -> bool :
135
- """Returns whether an input precision is in supported set."""
136
- return precision in [TRTMode .FLOAT32 , TRTMode .FLOAT16 , TRTMode .INT8 ]
137
-
138
-
139
134
def calib_data_generator (onnx_bytes : bytes , input_tensors : list [np .ndarray ]):
140
135
"""The calibation data generator that yields calibration feed_dict to tensorrt."""
141
136
input_names = get_onnx_input_names (onnx .load_from_string (onnx_bytes ))
Original file line number Diff line number Diff line change @@ -20,11 +20,11 @@ def _assert_nodes_are_quantized(nodes):
20
20
for node in nodes :
21
21
for inp_idx , inp in enumerate (node .inputs ):
22
22
if isinstance (inp , gs .Variable ):
23
- qnode = node
24
- # After quantization, the quantized node can be casted
25
- if qnode . i ( inp_idx ) .op == "Cast" :
26
- qnode = qnode .i (inp_idx )
27
- assert qnode . i ( inp_idx ) .op == "DequantizeLinear" , (
28
- f"Input '{ inp .name } ' of node '{ qnode .name } ' is not quantized but should be!"
23
+ producer = node . i ( inp_idx )
24
+ # Quantized path may include a Cast right after DQ
25
+ if producer and producer .op == "Cast" :
26
+ producer = producer .i (0 )
27
+ assert producer and producer .op == "DequantizeLinear" , (
28
+ f"Input '{ inp .name } ' of node '{ node .name } ' is not quantized but should be!"
29
29
)
30
30
return True
You can’t perform that action at this time.
0 commit comments