Skip to content

Commit ed58324

Browse files
authored
[NVBUG: 5536887] Check output name map while creating inputs for per node calibration (#509)
## What does this PR do? **Type of change:** Bug fix **Overview:** - input to a node could be model output - Hence we also check the `output_name_map` while creating the single node inputs ## Testing ```python python -m modelopt.onnx.quantization --onnx_path=aurora_model_batch_5.onnx --calibrate_per_node ``` ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: Yes - **Did you add or update any necessary documentation?**: No - **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No --------- Signed-off-by: ajrasane <[email protected]>
1 parent a94f95d commit ed58324

File tree

1 file changed

+24
-36
lines changed

1 file changed

+24
-36
lines changed

modelopt/onnx/quantization/ort_patching.py

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,15 +1248,12 @@ def add_reduce_min_max(tensor_name):
12481248
value_info_found = False
12491249
# If a node input is not an initializer, set it as a model input
12501250
if not is_input_initializer:
1251-
if input_name in value_info_name_map:
1252-
single_node_model_inputs.append(value_info_name_map[input_name])
1253-
single_node_model_input_names.append(input_name)
1254-
value_info_found = True
1255-
1256-
if not value_info_found and input_name in input_name_map:
1257-
single_node_model_inputs.append(input_name_map[input_name])
1258-
single_node_model_input_names.append(input_name)
1259-
value_info_found = True
1251+
for name_map in [value_info_name_map, input_name_map, output_name_map]:
1252+
if input_name in name_map:
1253+
single_node_model_inputs.append(name_map[input_name])
1254+
single_node_model_input_names.append(input_name)
1255+
value_info_found = True
1256+
break
12601257

12611258
if not value_info_found:
12621259
raise ValueError(
@@ -1266,15 +1263,12 @@ def add_reduce_min_max(tensor_name):
12661263
# Process each output for node
12671264
for output_name in node.output:
12681265
value_info_found = False
1269-
if output_name in value_info_name_map:
1270-
single_node_model_outputs.append(value_info_name_map[output_name])
1271-
single_node_model_output_names.append(output_name)
1272-
value_info_found = True
1273-
1274-
if not value_info_found and output_name in output_name_map:
1275-
single_node_model_outputs.append(output_name_map[output_name])
1276-
single_node_model_output_names.append(output_name)
1277-
value_info_found = True
1266+
for name_map in [value_info_name_map, output_name_map]:
1267+
if output_name in name_map:
1268+
single_node_model_outputs.append(name_map[output_name])
1269+
single_node_model_output_names.append(output_name)
1270+
value_info_found = True
1271+
break
12781272

12791273
if not value_info_found:
12801274
raise ValueError(
@@ -1363,15 +1357,12 @@ def _augment_graph_histogram_calibrater_single_node_calibration(calibrater):
13631357
# If a node input is not an initializer, set it as a model input
13641358
if not is_input_initializer:
13651359
value_info_found = False
1366-
if input_name in value_info_name_map:
1367-
single_node_model_inputs.append(value_info_name_map[input_name])
1368-
single_node_model_input_names.append(input_name)
1369-
value_info_found = True
1370-
1371-
if not value_info_found and input_name in input_name_map:
1372-
single_node_model_inputs.append(input_name_map[input_name])
1373-
single_node_model_input_names.append(input_name)
1374-
value_info_found = True
1360+
for name_map in [value_info_name_map, input_name_map, output_name_map]:
1361+
if input_name in name_map:
1362+
single_node_model_inputs.append(name_map[input_name])
1363+
single_node_model_input_names.append(input_name)
1364+
value_info_found = True
1365+
break
13751366

13761367
if not value_info_found:
13771368
raise ValueError(
@@ -1381,15 +1372,12 @@ def _augment_graph_histogram_calibrater_single_node_calibration(calibrater):
13811372
# Process each output for node
13821373
for output_name in node.output:
13831374
value_info_found = False
1384-
if output_name in value_info_name_map:
1385-
single_node_model_outputs.append(value_info_name_map[output_name])
1386-
single_node_model_output_names.append(output_name)
1387-
value_info_found = True
1388-
1389-
if not value_info_found and output_name in output_name_map:
1390-
single_node_model_outputs.append(output_name_map[output_name])
1391-
single_node_model_output_names.append(output_name)
1392-
value_info_found = True
1375+
for name_map in [value_info_name_map, output_name_map]:
1376+
if output_name in name_map:
1377+
single_node_model_outputs.append(name_map[output_name])
1378+
single_node_model_output_names.append(output_name)
1379+
value_info_found = True
1380+
break
13931381

13941382
if not value_info_found:
13951383
raise ValueError(

0 commit comments

Comments
 (0)