Skip to content

Commit dadd0c4

Browse files
[MIGraphX EP] Fix MIGraphX mixed precision run input parameters (#20982)
See #20643 ### Description Changes order of how we perform quantization to better support mixed precision and fixes a bug found with parameters of inputs for int8 quantization not being correctly handled. We now perform int8 quantization first on a full precision input model, before then quantizing the model to fp16 for remain ops that aren't quantized. The former case was causing us to use a low precision input which could cause larger values to be inserted than intended to the model when int8 quantization is perform. The symptom of this was a failure during quantization steps. Similar to the above input parameters were being uninitialized and resulting in similar failure during int8 quantization. GPU faults were intermittent but present as using uninitialized memory created undefined behavior when we started testing more complex models during mixed precision. ### Motivation and Context In some cases we've seen random data and/or invalid values entering into compiled onnx graphs. This is due to input parameters to the MIGraphX Graph not being set correctly when mixed precision (int8 + fp16) is used and ordering of quantization steps is causes a lower precision model to be used to perform int8 quantization. In most cases the failure is silent/intermittent. In some cases we've observed gpu faults due to out of bounds values being set. This change is required as a large input parameter to the MIGraphX graph is initialized to a large random value, and the next operator is using that for indexing, we get undefined behavior and a GPU fault.
1 parent 809cb26 commit dadd0c4

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
11721172
}
11731173

11741174
std::vector<std::string> input_names, output_names;
1175-
no_input_shape = no_input_shape or get_input_output_names(graph_body_viewer, input_names, output_names);
1175+
no_input_shape = no_input_shape || get_input_output_names(graph_body_viewer, input_names, output_names);
11761176

11771177
// by parsing the model_proto, create a program corresponding to
11781178
// the input fused_node
@@ -1356,7 +1356,6 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
13561356
quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(std::move(&cal_val))));
13571357
}
13581358
quant_opts.add_calibration_data(quant_params);
1359-
13601359
// specify thing we want to int8 quantize
13611360
quant_opts.add_op_name("convolution");
13621361
quant_opts.add_op_name("dot");

0 commit comments

Comments
 (0)