Skip to content

Commit 6cf8d12

Browse files
TedThemistokleousfs-eire
authored andcommitted
[MIGraphX EP] Link FP4 types between OnnxRT and MIGraphX APIs (#26231)
Do this so that MIGraphX can take in fp4 types from input/output tensors and then use that to perform an inference via the MIGraphX API. ### Description <!-- Describe your changes. --> Mirroed changes going into ROCm 7.1 build. Cherry -picked mainline OnnxRT changes to get fp4 tensor support before adding this ontop. Moving this to mainline OnnxRt to enable the MIGraphX EP to allow for fp4 input/output tensors ROCm#176 ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Add fp4 support to MIGraphX EP
1 parent e23ea25 commit 6cf8d12

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) {
268268
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16:
269269
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16:
270270
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT:
271+
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT4E2M1:
271272
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN:
272273
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ:
273274
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
@@ -318,6 +319,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
318319
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
319320
mgx_type = migraphx_shape_fp8e5m2fnuz_type;
320321
break;
322+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1:
323+
mgx_type = migraphx_shape_fp4x2_type;
324+
break;
321325
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
322326
mgx_type = migraphx_shape_int8_type;
323327
break;

0 commit comments

Comments
 (0)