Skip to content

Commit 36fc8c8

Browse files
TedThemistokleousTed Themistokleous
andauthored
[MIGraphX EP] Add migx ep fp8 support and int4 weights (microsoft#23534)
* Add fp8 and int4 types in supported list for ONNX Runtime EP * Add support for int4 inputs. Map things to int8 right now, as we don't explicitly set an int4 input type and pack/unpack int4 operands * Add flag to allow for fp8 quantization through the ONNX Runtime API * Add fp8 quantization to the compile stage of the MIGraphX EP. Mirror the same calibration code we use for int8 and just change which quantize we call through the MIGraphX API * Clean up logging * Clean up and encapsulate quantization / compile functions - Add additional flags for fp8 that are shared with int8 - Add lockout warning message when int8/fp8 are used at the same time * Run lintrunner pass * Fix session options inputs + add better logging. Previous runs using session options failed as we were not pulling in inputs from the Python interface. This, plus additional logging, allowed me to track which options were invoked via env and which were added during the start of an inference session * Fix naming for save/load path variables to be consistent with enable. * Print only env variables that are set, as warnings. We need this so the user knows whether any of the environment variables are active in the background, to ensure proper consistency between runs. --------- ### Description <!-- Describe your changes. --> Changes to clean up the MIGraphX EP quantization code as well as adding fp8 quantization support along with int4 support. Cleanup changes handle a few instances of issues seen with the Python interface when taking in provider options ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Required as we fix ignored flags when using provider_options for the MIGraphX EP. Adding fp8 quantization through the MIGraphX API. Adding int4 weight support for packed int4 weights for MIGraphX inference --------- Co-authored-by: Ted Themistokleous <[email protected]>
1 parent 665922d commit 36fc8c8

File tree

8 files changed

+244
-114
lines changed

8 files changed

+244
-114
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,7 @@ typedef struct OrtTensorRTProviderOptions {
673673
typedef struct OrtMIGraphXProviderOptions {
674674
int device_id; // hip device id.
675675
int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true
676+
int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true
676677
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
677678
int migraphx_use_native_calibration_table; // MIGraphX INT8 cal table. Default 0 = false, nonzero = true
678679
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 207 additions & 104 deletions
Large diffs are not rendered by default.

onnxruntime/core/providers/migraphx/migraphx_execution_provider.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@ namespace onnxruntime {
1717

1818
namespace migraphx_env_vars {
1919
static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE";
20+
static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE";
2021
static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE";
2122
static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
2223
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
2324
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
2425
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
2526
static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
26-
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
27+
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH";
2728
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
28-
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";
29+
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH";
2930
static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE";
3031

3132
}; // namespace migraphx_env_vars
@@ -43,6 +44,7 @@ struct MIGraphXFuncState {
4344
std::mutex* mgx_mu_ptr = nullptr;
4445
bool no_input_shape = false;
4546
bool fp16_enable = false;
47+
bool fp8_enable = false;
4648
bool int8_enable = false;
4749
bool int8_calibration_cache_available = false;
4850
std::unordered_map<std::string, float> dynamic_range_map;
@@ -60,6 +62,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
6062
explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info);
6163
~MIGraphXExecutionProvider();
6264

65+
void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info);
66+
void get_flags_from_env();
67+
void print_migraphx_ep_flags();
68+
6369
Status Sync() const override;
6470

6571
Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
@@ -91,6 +97,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
9197
private:
9298
MIGraphXExecutionProviderInfo info_;
9399
bool fp16_enable_ = false;
100+
bool fp8_enable_ = false;
94101
bool int8_enable_ = false;
95102
std::string int8_calibration_cache_name_;
96103
bool int8_calibration_cache_available_ = false;

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace migraphx {
1414
namespace provider_option_names {
1515
constexpr const char* kDeviceId = "device_id";
1616
constexpr const char* kFp16Enable = "trt_fp16_enable";
17+
constexpr const char* kFp8Enable = "migx_fp8_enable";
1718
constexpr const char* kInt8Enable = "migx_int8_enable";
1819
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
1920
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
@@ -43,6 +44,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
4344
return Status::OK();
4445
})
4546
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
47+
.AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable)
4648
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
4749
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
4850
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
@@ -56,6 +58,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
5658
const ProviderOptions options{
5759
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
5860
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
61+
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)},
5962
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
6063
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
6164
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
@@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
6871
const ProviderOptions options{
6972
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
7073
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
74+
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)},
7175
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
7276
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
7377
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo {
1616
std::string target_device;
1717
OrtDevice::DeviceId device_id{0};
1818
bool fp16_enable{false};
19+
bool fp8_enable{false};
1920
bool int8_enable{false};
2021
std::string int8_calibration_table_name{""};
2122
bool int8_use_native_calibration_table{false};

onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider {
6060
info.device_id = static_cast<OrtDevice::DeviceId>(options.device_id);
6161
info.target_device = "gpu";
6262
info.fp16_enable = options.migraphx_fp16_enable;
63+
info.fp8_enable = options.migraphx_fp8_enable;
6364
info.exhaustive_tune = options.migraphx_exhaustive_tune;
6465
info.int8_enable = options.migraphx_int8_enable;
6566
info.int8_calibration_table_name = "";
@@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider {
8586
auto& migx_options = *reinterpret_cast<OrtMIGraphXProviderOptions*>(provider_options);
8687
migx_options.device_id = internal_options.device_id;
8788
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
89+
migx_options.migraphx_fp8_enable = internal_options.fp8_enable;
8890
migx_options.migraphx_int8_enable = internal_options.int8_enable;
8991
migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune;
9092

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,7 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
946946
0,
947947
0,
948948
0,
949+
0,
949950
nullptr,
950951
1,
951952
"./compiled_model.mxr",
@@ -966,7 +967,17 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
966967
params.migraphx_fp16_enable = false;
967968
} else {
968969
ORT_THROW(
969-
"[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be"
970+
"[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be"
971+
" 'True' or 'False'. Default value is 'False'.\n");
972+
}
973+
} else if (option.first == "migraphx_fp8_enable") {
974+
if (option.second == "True" || option.second == "true") {
975+
params.migraphx_fp8_enable = true;
976+
} else if (option.second == "False" || option.second == "false") {
977+
params.migraphx_fp8_enable = false;
978+
} else {
979+
ORT_THROW(
980+
"[ERROR] [MIGraphX] The value for the key 'migraphx_fp8_enable' should be"
970981
" 'True' or 'False'. Default value is 'False'.\n");
971982
}
972983
} else if (option.first == "migraphx_int8_enable") {
@@ -976,7 +987,7 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
976987
params.migraphx_int8_enable = false;
977988
} else {
978989
ORT_THROW(
979-
"[ERROR] [MIGraphX] The value for the key 'migx_int8_enable' should be"
990+
"[ERROR] [MIGraphX] The value for the key 'migraphx_int8_enable' should be"
980991
" 'True' or 'False'. Default value is 'False'.\n");
981992
}
982993
} else if (option.first == "migraphx_int8_calibration_table_name") {
@@ -985,7 +996,7 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
985996
params.migraphx_int8_calibration_table_name = calibration_table.c_str();
986997
} else {
987998
ORT_THROW(
988-
"[ERROR] [MIGraphX] The value for the key 'migx_int8_calibration_table_name' should be a "
999+
"[ERROR] [MIGraphX] The value for the key 'migraphx_int8_calibration_table_name' should be a "
9891000
"file name i.e. 'cal_table'.\n");
9901001
}
9911002
} else if (option.first == "migraphx_use_native_calibration_table") {
@@ -995,7 +1006,7 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
9951006
params.migraphx_use_native_calibration_table = false;
9961007
} else {
9971008
ORT_THROW(
998-
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
1009+
"[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be"
9991010
" 'True' or 'False'. Default value is 'False'.\n");
10001011
}
10011012
} else if (option.first == "migraphx_save_compiled_model") {
@@ -1005,7 +1016,7 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
10051016
params.migraphx_fp16_enable = false;
10061017
} else {
10071018
ORT_THROW(
1008-
"[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
1019+
"[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be"
10091020
" 'True' or 'False'. Default value is 'False'.\n");
10101021
}
10111022
} else if (option.first == "migraphx_save_model_path") {
@@ -1014,7 +1025,7 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
10141025
params.migraphx_save_model_path = save_model_path.c_str();
10151026
} else {
10161027
ORT_THROW(
1017-
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
1028+
"[ERROR] [MIGraphX] The value for the key 'migraphx_save_model_name' should be a "
10181029
"file name i.e. 'compiled_model.mxr'.\n");
10191030
}
10201031
} else if (option.first == "migraphx_load_compiled_model") {
@@ -1024,7 +1035,7 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
10241035
params.migraphx_fp16_enable = false;
10251036
} else {
10261037
ORT_THROW(
1027-
"[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
1038+
"[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be"
10281039
" 'True' or 'False'. Default value is 'False'.\n");
10291040
}
10301041
} else if (option.first == "migraphx_load_model_path") {
@@ -1033,7 +1044,7 @@ static std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory
10331044
params.migraphx_load_model_path = load_model_path.c_str();
10341045
} else {
10351046
ORT_THROW(
1036-
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
1047+
"[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_name' should be a "
10371048
"file name i.e. 'compiled_model.mxr'.\n");
10381049
}
10391050
} else if (option.first == "migraphx_exhaustive_tune") {

onnxruntime/test/util/default_providers.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
8585
0,
8686
0,
8787
0,
88+
0,
8889
nullptr,
8990
1,
9091
"./compiled_model.mxr",

0 commit comments

Comments
 (0)