Skip to content

Commit 554fb4a

Browse files
zz002wcy123Zhenze
authored
[VitisAI EP] export InferShapes to VitisAIEP (microsoft#23881)
### Description [VitisAI EP] export InferShapes to VitisAIEP --------- Co-authored-by: Wang Chunye <[email protected]> Co-authored-by: Zhenze <[email protected]>
1 parent b803429 commit 554fb4a

File tree

6 files changed

+31
-2
lines changed

6 files changed

+31
-2
lines changed

onnxruntime/core/providers/shared_library/provider_interfaces.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,8 @@ struct ProviderHost {
611611
virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
612612
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;
613613

614+
virtual void InferShapes(const std::string& m, const std::string& save_path) = 0;
615+
virtual void InferShapes(ONNX_NAMESPACE::ModelProto& m) = 0;
614616
virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op) = 0;
615617
virtual void DeregisterSchema(const std::string& domain, const std::string& op_type, int version) = 0;
616618
virtual const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) = 0;
@@ -1010,6 +1012,7 @@ struct ProviderHost {
10101012
virtual const Graph* Graph__ParentGraph(const Graph* p) const = 0;
10111013
virtual Graph* Graph__MutableParentGraph(Graph* p) = 0;
10121014
virtual const std::string& Graph__Name(const Graph* p) const noexcept = 0;
1015+
virtual void Graph__SetName(Graph* p, const std::string& name) const noexcept = 0;
10131016
virtual const std::filesystem::path& Graph__ModelPath(const Graph* p) const = 0;
10141017
virtual const std::vector<const NodeArg*>& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0;
10151018
virtual bool Graph__IsSubgraph(const Graph* p) = 0;

onnxruntime/core/providers/shared_library/provider_wrappedtypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,7 @@ struct Graph final {
10501050
const Graph* ParentGraph() const { return g_host->Graph__ParentGraph(this); }
10511051
Graph* MutableParentGraph() { return g_host->Graph__MutableParentGraph(this); }
10521052
const std::string& Name() const noexcept { return g_host->Graph__Name(this); }
1053+
void SetName(const std::string& name) noexcept { return g_host->Graph__SetName(this, name); }
10531054
const std::filesystem::path& ModelPath() const { return g_host->Graph__ModelPath(this); }
10541055
const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); }
10551056
bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); }

onnxruntime/core/providers/vitisai/imp/global_api.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,10 +360,19 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
360360
};
361361
the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.Nodes()); };
362362
the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); };
363+
the_global_api.graph_set_name = [](Graph& graph, const char* name) -> void { return graph.SetName(std::string(name)); };
363364
the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span<const Node* const> from,
364365
const auto& enter, const auto& leave, const auto& stop) {
365366
graph.ReverseDFSFrom(from, enter, leave, nullptr, stop);
366367
};
368+
369+
the_global_api.graph_infer_shapes_from_filepath = [](const std::string& m, const std::string& save_path) -> auto { return Provider_GetHost()->InferShapes(m, save_path); };
370+
the_global_api.graph_to_graph_proto = [](const Graph& graph) -> ONNX_NAMESPACE::GraphProto* {
371+
return graph.ToGraphProto().release();
372+
};
373+
the_global_api.graph_proto_delete = [](ONNX_NAMESPACE::GraphProto* p) { delete p; };
374+
the_global_api.graph_infer_shapes = [](ONNX_NAMESPACE::ModelProto& m) -> auto { return Provider_GetHost()->InferShapes(m); };
375+
367376
// node
368377
the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs;
369378
the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args;

onnxruntime/core/providers/vitisai/include/vaip/my_ort.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ struct NodeAttributes;
2020
namespace ONNX_NAMESPACE {
2121
struct AttributeProto;
2222
struct TensorProto;
23+
struct GraphProto;
2324
struct ModelProto;
2425
#ifndef USE_VITISAI
2526
enum TensorProto_DataType : int {
@@ -71,6 +72,7 @@ enum AttributeProto_AttributeType : int {
7172
namespace vaip_core {
7273
class GraphHolder;
7374
using ONNX_NAMESPACE::AttributeProto;
75+
using ONNX_NAMESPACE::GraphProto;
7476
using ONNX_NAMESPACE::ModelProto;
7577
using ONNX_NAMESPACE::TensorProto;
7678
using onnxruntime::Graph;

onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct OrtApi;
1313

1414
namespace vaip_core {
1515

16-
#define VAIP_ORT_API_MAJOR (14u)
16+
#define VAIP_ORT_API_MAJOR (16u)
1717
#define VAIP_ORT_API_MINOR (0u)
1818
#define VAIP_ORT_API_PATCH (0u)
1919
struct OrtApiForVaip {
@@ -249,7 +249,13 @@ struct OrtApiForVaip {
249249
const std::function<bool(const Node*)>& leave,
250250
const std::function<bool(const Node*, const Node*)>& comp,
251251
const std::function<bool(const Node* from, const Node* to)>&
252-
stop); // [103]
252+
stop); // [103]
253+
void (*graph_set_name)(Graph& graph, const char* name); // [104]
254+
void (*graph_infer_shapes_from_filepath)(
255+
const std::string& m, const std::string& save_path); // [105]
256+
GraphProto* (*graph_to_graph_proto)(const Graph& graph); // [106]
257+
void (*graph_proto_delete)(GraphProto* p); // [107]
258+
void (*graph_infer_shapes)(ModelProto& m); // [108]
253259
};
254260

255261
#ifndef USE_VITISAI

onnxruntime/core/session/provider_bridge_ort.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "core/session/onnxruntime_c_api.h"
4444
#include "core/common/string_helper.h"
4545
#include <utility>
46+
#include "onnx/shape_inference/implementation.h"
4647

4748
#ifdef ENABLE_TRAINING
4849
#ifdef ENABLE_TRAINING_TORCH_INTEROP
@@ -771,6 +772,12 @@ struct ProviderHostImpl : ProviderHost {
771772
int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->metadata_props_size(); }
772773
ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_metadata_props(); }
773774

775+
void InferShapes(const std::string& m, const std::string& save_path) override {
776+
return ONNX_NAMESPACE::shape_inference::InferShapes(m, save_path);
777+
}
778+
void InferShapes(ONNX_NAMESPACE::ModelProto& m) override {
779+
return ONNX_NAMESPACE::shape_inference::InferShapes(m);
780+
}
774781
void RegisterSchema(const std::string& domain, const OrtCustomOp* op) override {
775782
auto& domain_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
776783
const auto& domain_to_version_map = domain_instance.Map();
@@ -1268,6 +1275,7 @@ struct ProviderHostImpl : ProviderHost {
12681275
const Graph* Graph__ParentGraph(const Graph* p) const override { return p->ParentGraph(); }
12691276
Graph* Graph__MutableParentGraph(Graph* p) override { return p->MutableParentGraph(); }
12701277
const std::string& Graph__Name(const Graph* p) const noexcept override { return p->Name(); }
1278+
void Graph__SetName(Graph* p, const std::string& name) const noexcept override { return p->SetName(name); }
12711279
const std::filesystem::path& Graph__ModelPath(const Graph* p) const override { return p->ModelPath(); }
12721280
const std::vector<const NodeArg*>& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept override { return p->GetInputsIncludingInitializers(); }
12731281
bool Graph__IsSubgraph(const Graph* p) override { return p->IsSubgraph(); }

0 commit comments

Comments
 (0)