Skip to content

Commit 5ddd34e

Browse files
authored
Add Node_GetAttributes C API for EP ABI (microsoft#25143)
This PRs adds additional Node_GetAttributes C API for EP ABI use. It's based on microsoft#24887
1 parent 93ee7bf commit 5ddd34e

File tree

8 files changed

+245
-0
lines changed

8 files changed

+245
-0
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e
499499
typedef enum OrtTypeTag {
500500
ORT_TYPE_TAG_Void,
501501
ORT_TYPE_TAG_OrtValueInfo,
502+
ORT_TYPE_TAG_OrtOpAttr,
502503
ORT_TYPE_TAG_OrtNode,
503504
ORT_TYPE_TAG_OrtGraph,
504505
} OrtTypeTag;
@@ -5881,6 +5882,41 @@ struct OrtApi {
58815882
*/
58825883
ORT_API2_STATUS(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** implicit_inputs);
58835884

5885+
/** \brief Returns a node's attributes as OrtOpAttr instances.
5886+
*
5887+
* \param[in] node The OrtNode instance.
5888+
* \param[out] attributes Output parameter set to the OrtArrayOfConstObjects instance containing the node's attributes
5889+
* as OrtOpAttr instances. Must be released by calling ReleaseArrayOfConstObjects.
5890+
*
5891+
* \snippet{doc} snippets.dox OrtStatus Return Value
5892+
*
5893+
* \since Version 1.23.
5894+
*/
5895+
ORT_API2_STATUS(Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attributes);
5896+
5897+
/** \brief Gets the OrtNode's attribute as OrtOpAttr by name.
5898+
*
5899+
* \param[in] node The OrtNode instance.
5900+
* \param[in] attribute_name The name of the attribute
5901+
* \param[out] attribute Output the attribute if its name matches 'attribute_name', otherwise output nullptr.
5902+
*
5903+
* \snippet{doc} snippets.dox OrtStatus Return Value
5904+
*
5905+
* \since Version 1.23.
5906+
*/
5907+
ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute);
5908+
5909+
/** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr.
5910+
*
5911+
* \param[in] attribute The OrtOpAttr instance.
5912+
* \param[out] type Output the attribute type as OrtOpAttrType.
5913+
*
5914+
* \snippet{doc} snippets.dox OrtStatus Return Value
5915+
*
5916+
* \since Version 1.23.
5917+
*/
5918+
ORT_API2_STATUS(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type);
5919+
58845920
/** \brief Get the subgraphs, as OrtGraph instances, contained by the given node.
58855921
*
58865922
* Certain operator types (e.g., If and Loop) contain nested subgraphs.

onnxruntime/core/graph/abi_graph_types.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,13 @@ struct OrtNode {
202202
/// <returns>A status indicating success or an error.</returns>
203203
virtual onnxruntime::Status GetImplicitInputs(std::unique_ptr<OrtArrayOfConstObjects>& implicit_inputs) const = 0;
204204

205+
/// <summary>
206+
/// Gets the node's attributes as an array of OrtOpAttr elements wrapped in an OrtArrayOfConstObjects.
207+
/// </summary>
208+
/// <param name="attrs">Output parameter set to the node's attributes.</param>
209+
/// <returns>A status indicating success or an error.</returns>
210+
virtual onnxruntime::Status GetAttributes(std::unique_ptr<OrtArrayOfConstObjects>& attrs) const = 0;
211+
205212
/// <summary>
206213
/// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node).
207214
/// </summary>

onnxruntime/core/graph/ep_api_types.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,20 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph,
9191
ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_inputs, ep_node_inputs);
9292
ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_outputs, ep_node_outputs);
9393

94+
const auto& node_attrs = node.GetAttributes();
95+
std::unordered_map<std::string, std::unique_ptr<ONNX_NAMESPACE::AttributeProto>> ep_node_attributes_map;
96+
std::vector<OrtOpAttr*> ep_node_attributes;
97+
98+
if (node_attrs.size() > 0) {
99+
ep_node_attributes.reserve(node_attrs.size());
100+
101+
for (const auto& item : node_attrs) {
102+
auto attr = std::make_unique<ONNX_NAMESPACE::AttributeProto>(item.second); // Copy AttributeProto and owned by this EpNode object.
103+
ep_node_attributes.push_back(reinterpret_cast<OrtOpAttr*>(attr.get()));
104+
ep_node_attributes_map.emplace(item.first, std::move(attr));
105+
}
106+
}
107+
94108
std::vector<SubgraphState> ep_node_subgraphs;
95109
std::vector<EpValueInfo*> ep_node_implicit_inputs;
96110

@@ -115,6 +129,8 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph,
115129

116130
ep_node->inputs_ = std::move(ep_node_inputs);
117131
ep_node->outputs_ = std::move(ep_node_outputs);
132+
ep_node->attributes_map_ = std::move(ep_node_attributes_map);
133+
ep_node->attributes_ = std::move(ep_node_attributes);
118134
ep_node->implicit_inputs_ = std::move(ep_node_implicit_inputs);
119135
ep_node->subgraphs_ = std::move(ep_node_subgraphs);
120136

@@ -169,6 +185,17 @@ Status EpNode::GetImplicitInputs(std::unique_ptr<OrtArrayOfConstObjects>& result
169185
return Status::OK();
170186
}
171187

188+
Status EpNode::GetAttributes(std::unique_ptr<OrtArrayOfConstObjects>& result) const {
189+
result = std::make_unique<OrtArrayOfConstObjects>(ORT_TYPE_TAG_OrtOpAttr);
190+
result->storage.reserve(attributes_.size());
191+
192+
for (const OrtOpAttr* attr : attributes_) {
193+
result->storage.push_back(attr);
194+
}
195+
196+
return Status::OK();
197+
}
198+
172199
Status EpNode::GetSubgraphs(std::unique_ptr<OrtArrayOfConstObjects>& result) const {
173200
result = std::make_unique<OrtArrayOfConstObjects>(ORT_TYPE_TAG_OrtGraph);
174201
result->storage.reserve(subgraphs_.size());
@@ -197,6 +224,15 @@ gsl::span<const EpValueInfo* const> EpNode::GetOutputsSpan() const {
197224
return outputs_;
198225
}
199226

227+
const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const {
228+
auto iter = attributes_map_.find(name);
229+
if (iter == attributes_map_.end()) {
230+
return nullptr;
231+
} else {
232+
return reinterpret_cast<const OrtOpAttr*>(iter->second.get());
233+
}
234+
}
235+
200236
//
201237
// EpValueInfo
202238
//

onnxruntime/core/graph/ep_api_types.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ struct EpNode : public OrtNode {
164164
// Gets the node's implicit inputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects.
165165
Status GetImplicitInputs(std::unique_ptr<OrtArrayOfConstObjects>& inputs) const override;
166166

167+
// Gets the node's attributes as OrtOpAttr instances wrapped in an OrtArrayOfConstObjects.
168+
Status GetAttributes(std::unique_ptr<OrtArrayOfConstObjects>& attrs) const override;
169+
167170
// Gets the subgraphs contained by this node.
168171
Status GetSubgraphs(std::unique_ptr<OrtArrayOfConstObjects>& subgraphs) const override;
169172

@@ -186,6 +189,9 @@ struct EpNode : public OrtNode {
186189
// Helper that returns this node's outputs as a span of EpValueInfo pointers.
187190
gsl::span<const EpValueInfo* const> GetOutputsSpan() const;
188191

192+
// Helper that gets the node's attributes by name.
193+
const OrtOpAttr* GetAttribute(const std::string& name) const;
194+
189195
private:
190196
// Back pointer to containing graph. Useful when traversing through nested subgraphs.
191197
// Will be nullptr if the EpNode was created without an owning graph.
@@ -196,6 +202,9 @@ struct EpNode : public OrtNode {
196202
InlinedVector<EpValueInfo*> inputs_;
197203
InlinedVector<EpValueInfo*> outputs_;
198204

205+
std::unordered_map<std::string, std::unique_ptr<ONNX_NAMESPACE::AttributeProto>> attributes_map_;
206+
std::vector<OrtOpAttr*> attributes_;
207+
199208
std::vector<EpValueInfo*> implicit_inputs_;
200209
std::vector<SubgraphState> subgraphs_;
201210
};

onnxruntime/core/graph/model_editor_api_types.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ struct ModelEditorNode : public OrtNode {
114114
"OrtModelEditorApi does not support getting the implicit inputs for OrtNode");
115115
}
116116

117+
Status GetAttributes(std::unique_ptr<OrtArrayOfConstObjects>& /*attrs*/) const override {
118+
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
119+
"OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode");
120+
}
121+
117122
Status GetSubgraphs(std::unique_ptr<OrtArrayOfConstObjects>& /*subgraphs*/) const override {
118123
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
119124
"OrtModelEditorApi does not support getting the subgraphs for OrtNode");

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "core/graph/constants.h"
3232
#include "core/graph/graph.h"
3333
#include "core/graph/model_editor_api_types.h"
34+
#include "core/graph/ep_api_types.h"
3435
#include "core/providers/get_execution_providers.h"
3536
#include "core/session/abi_session_options_impl.h"
3637
#include "core/session/allocator_adapters.h"
@@ -2858,6 +2859,82 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetImplicitInputs, _In_ const OrtNode* node,
28582859
API_IMPL_END
28592860
}
28602861

2862+
ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attributes) {
2863+
API_IMPL_BEGIN
2864+
if (attributes == nullptr) {
2865+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attributes' argument is NULL");
2866+
}
2867+
2868+
std::unique_ptr<OrtArrayOfConstObjects> array;
2869+
ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetAttributes(array));
2870+
2871+
*attributes = array.release();
2872+
return nullptr;
2873+
API_IMPL_END
2874+
}
2875+
2876+
ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute) {
2877+
API_IMPL_BEGIN
2878+
if (attribute == nullptr) {
2879+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attribute' argument is NULL");
2880+
}
2881+
2882+
const EpNode* ep_node = EpNode::ToInternal(node);
2883+
if (ep_node == nullptr) {
2884+
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName.");
2885+
}
2886+
2887+
*attribute = ep_node->GetAttribute(attribute_name);
2888+
2889+
if (*attribute) {
2890+
return nullptr;
2891+
} else {
2892+
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist.");
2893+
}
2894+
2895+
API_IMPL_END
2896+
}
2897+
2898+
ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type) {
2899+
API_IMPL_BEGIN
2900+
const auto attr = attribute->attr_proto;
2901+
auto onnx_attr_type = attribute->attr_proto.type();
2902+
switch (onnx_attr_type) {
2903+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_UNDEFINED: {
2904+
*type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED;
2905+
break;
2906+
}
2907+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT: {
2908+
*type = OrtOpAttrType::ORT_OP_ATTR_INT;
2909+
break;
2910+
}
2911+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS: {
2912+
*type = OrtOpAttrType::ORT_OP_ATTR_INTS;
2913+
break;
2914+
}
2915+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT: {
2916+
*type = OrtOpAttrType::ORT_OP_ATTR_FLOAT;
2917+
break;
2918+
}
2919+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS: {
2920+
*type = OrtOpAttrType::ORT_OP_ATTR_FLOATS;
2921+
break;
2922+
}
2923+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING: {
2924+
*type = OrtOpAttrType::ORT_OP_ATTR_STRING;
2925+
break;
2926+
}
2927+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS: {
2928+
*type = OrtOpAttrType::ORT_OP_ATTR_STRINGS;
2929+
break;
2930+
}
2931+
default:
2932+
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type.");
2933+
}
2934+
return nullptr;
2935+
API_IMPL_END
2936+
}
2937+
28612938
ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs) {
28622939
API_IMPL_BEGIN
28632940
if (subgraphs == nullptr) {
@@ -3558,6 +3635,9 @@ static constexpr OrtApi ort_api_1_to_23 = {
35583635
&OrtApis::Node_GetInputs,
35593636
&OrtApis::Node_GetOutputs,
35603637
&OrtApis::Node_GetImplicitInputs,
3638+
&OrtApis::Node_GetAttributes,
3639+
&OrtApis::Node_GetAttributeByName,
3640+
&OrtApis::OpAttr_GetType,
35613641
&OrtApis::Node_GetSubgraphs,
35623642
&OrtApis::Node_GetParentGraph,
35633643

onnxruntime/core/session/ort_apis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,9 @@ ORT_API_STATUS_IMPL(Node_GetSinceVersion, _In_ const OrtNode* node, _Out_ int* s
662662
ORT_API_STATUS_IMPL(Node_GetInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** inputs);
663663
ORT_API_STATUS_IMPL(Node_GetOutputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** outputs);
664664
ORT_API_STATUS_IMPL(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** implicit_inputs);
665+
ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attrs);
666+
ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute);
667+
ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type);
665668
ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs);
666669
ORT_API_STATUS_IMPL(Node_GetParentGraph, _In_ const OrtNode* node,
667670
_Outptr_result_maybenull_ const OrtGraph** parent_graph);

onnxruntime/test/ep_graph/test_ep_graph.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,75 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
452452

453453
CheckValueInfosCApi(graph_viewer, api_node_outputs, output_node_args);
454454

455+
// Check node attributes
456+
const auto& node_attrs = node->GetAttributes();
457+
458+
if (node_attrs.size() > 0) {
459+
OrtArrayOfConstObjects* api_node_attributes = nullptr;
460+
DeferOrtRelease<OrtArrayOfConstObjects> release_node_attributes(&api_node_attributes,
461+
ort_api.ReleaseArrayOfConstObjects);
462+
ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(api_node, &api_node_attributes));
463+
CheckArrayObjectType(api_node_attributes, ORT_TYPE_TAG_OrtOpAttr);
464+
465+
size_t attr_idx = 0;
466+
for (const auto& node_attr : node_attrs) {
467+
const OrtOpAttr* api_node_attr = nullptr;
468+
ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetElementAt(api_node_attributes, attr_idx,
469+
reinterpret_cast<const void**>(&api_node_attr)));
470+
ASSERT_NE(api_node_attr, nullptr);
471+
472+
api_node_attr = nullptr;
473+
ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(api_node, node_attr.first.c_str(), &api_node_attr));
474+
ASSERT_NE(api_node_attr, nullptr);
475+
476+
OrtOpAttrType api_node_attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED;
477+
478+
// It's possible that the type is defined in ONNX::AttributeProto_AttributeType but not in OrtOpAttrType, since the two are not in a 1:1 mapping.
479+
// In such cases, OpAttr_GetType will return a non-null status, and we simply skip the check here.
480+
OrtStatusPtr status = ort_api.OpAttr_GetType(api_node_attr, &api_node_attr_type);
481+
if (status != nullptr) {
482+
Ort::GetApi().ReleaseStatus(status);
483+
continue;
484+
}
485+
486+
ONNX_NAMESPACE::AttributeProto_AttributeType node_attr_type = node_attr.second.type();
487+
switch (node_attr_type) {
488+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_UNDEFINED: {
489+
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_UNDEFINED);
490+
break;
491+
}
492+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT: {
493+
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_INT);
494+
break;
495+
}
496+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS: {
497+
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_INTS);
498+
break;
499+
}
500+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT: {
501+
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_FLOAT);
502+
break;
503+
}
504+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS: {
505+
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_FLOATS);
506+
break;
507+
}
508+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING: {
509+
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_STRING);
510+
break;
511+
}
512+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS: {
513+
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_STRINGS);
514+
break;
515+
}
516+
default:
517+
// The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail.
518+
ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit."));
519+
}
520+
attr_idx++;
521+
}
522+
}
523+
455524
// Check node subgraphs
456525
std::vector<gsl::not_null<const Graph*>> node_subgraphs = node->GetSubgraphs();
457526

0 commit comments

Comments
 (0)