|
31 | 31 | #include "core/graph/constants.h" |
32 | 32 | #include "core/graph/graph.h" |
33 | 33 | #include "core/graph/model_editor_api_types.h" |
| 34 | +#include "core/graph/ep_api_types.h" |
34 | 35 | #include "core/providers/get_execution_providers.h" |
35 | 36 | #include "core/session/abi_session_options_impl.h" |
36 | 37 | #include "core/session/allocator_adapters.h" |
@@ -2858,6 +2859,82 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetImplicitInputs, _In_ const OrtNode* node, |
2858 | 2859 | API_IMPL_END |
2859 | 2860 | } |
2860 | 2861 |
|
| 2862 | +ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attributes) { |
| 2863 | + API_IMPL_BEGIN |
| 2864 | + if (attributes == nullptr) { |
| 2865 | + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attributes' argument is NULL"); |
| 2866 | + } |
| 2867 | + |
| 2868 | + std::unique_ptr<OrtArrayOfConstObjects> array; |
| 2869 | + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetAttributes(array)); |
| 2870 | + |
| 2871 | + *attributes = array.release(); |
| 2872 | + return nullptr; |
| 2873 | + API_IMPL_END |
| 2874 | +} |
| 2875 | + |
| 2876 | +ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute) { |
| 2877 | + API_IMPL_BEGIN |
| 2878 | + if (attribute == nullptr) { |
| 2879 | + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attribute' argument is NULL"); |
| 2880 | + } |
| 2881 | + |
| 2882 | + const EpNode* ep_node = EpNode::ToInternal(node); |
| 2883 | + if (ep_node == nullptr) { |
| 2884 | + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName."); |
| 2885 | + } |
| 2886 | + |
| 2887 | + *attribute = ep_node->GetAttribute(attribute_name); |
| 2888 | + |
| 2889 | + if (*attribute) { |
| 2890 | + return nullptr; |
| 2891 | + } else { |
| 2892 | + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); |
| 2893 | + } |
| 2894 | + |
| 2895 | + API_IMPL_END |
| 2896 | +} |
| 2897 | + |
| 2898 | +ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type) { |
| 2899 | + API_IMPL_BEGIN |
| 2900 | + const auto attr = attribute->attr_proto; |
| 2901 | + auto onnx_attr_type = attribute->attr_proto.type(); |
| 2902 | + switch (onnx_attr_type) { |
| 2903 | + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_UNDEFINED: { |
| 2904 | + *type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; |
| 2905 | + break; |
| 2906 | + } |
| 2907 | + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT: { |
| 2908 | + *type = OrtOpAttrType::ORT_OP_ATTR_INT; |
| 2909 | + break; |
| 2910 | + } |
| 2911 | + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS: { |
| 2912 | + *type = OrtOpAttrType::ORT_OP_ATTR_INTS; |
| 2913 | + break; |
| 2914 | + } |
| 2915 | + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT: { |
| 2916 | + *type = OrtOpAttrType::ORT_OP_ATTR_FLOAT; |
| 2917 | + break; |
| 2918 | + } |
| 2919 | + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS: { |
| 2920 | + *type = OrtOpAttrType::ORT_OP_ATTR_FLOATS; |
| 2921 | + break; |
| 2922 | + } |
| 2923 | + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING: { |
| 2924 | + *type = OrtOpAttrType::ORT_OP_ATTR_STRING; |
| 2925 | + break; |
| 2926 | + } |
| 2927 | + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS: { |
| 2928 | + *type = OrtOpAttrType::ORT_OP_ATTR_STRINGS; |
| 2929 | + break; |
| 2930 | + } |
| 2931 | + default: |
| 2932 | + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type."); |
| 2933 | + } |
| 2934 | + return nullptr; |
| 2935 | + API_IMPL_END |
| 2936 | +} |
| 2937 | + |
2861 | 2938 | ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs) { |
2862 | 2939 | API_IMPL_BEGIN |
2863 | 2940 | if (subgraphs == nullptr) { |
@@ -3558,6 +3635,9 @@ static constexpr OrtApi ort_api_1_to_23 = { |
3558 | 3635 | &OrtApis::Node_GetInputs, |
3559 | 3636 | &OrtApis::Node_GetOutputs, |
3560 | 3637 | &OrtApis::Node_GetImplicitInputs, |
| 3638 | + &OrtApis::Node_GetAttributes, |
| 3639 | + &OrtApis::Node_GetAttributeByName, |
| 3640 | + &OrtApis::OpAttr_GetType, |
3561 | 3641 | &OrtApis::Node_GetSubgraphs, |
3562 | 3642 | &OrtApis::Node_GetParentGraph, |
3563 | 3643 |
|
|
0 commit comments