Skip to content

Commit 319159b

Browse files
zz002Zhenze Wang
andauthored
[VitisAI]set-data_loaction-as-default-when-load-external-data (#19712)
### Description <!-- Describe your changes. --> set-data_loaction-as-default-when-load-external-data fix vitis ai ep can not get CutomOps by session_option register ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> VitisAI bug daily fixes when use pass: fuse_qdq_GEMM or fuse_qdq_MATMUL, get error like : Error Data of TensorProto ( tensor name: xxx) is stored externally and should not have data field.raw_data --------- Co-authored-by: Zhenze Wang <[email protected]>
1 parent 742595b commit 319159b

File tree

7 files changed

+32
-10
lines changed

7 files changed

+32
-10
lines changed

onnxruntime/core/providers/shared_library/provider_interfaces.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
#include <optional>
5+
#include <list>
56

67
// Public wrappers around internal ort interfaces (currently)
78
#include "core/providers/shared_library/provider_host_api.h"
@@ -34,6 +35,7 @@ struct ProviderHostCPU;
3435
class PhiloxGenerator;
3536
using ProviderType = const std::string&;
3637
class RandomGenerator;
38+
class IOnnxRuntimeOpSchemaCollection;
3739

3840
#ifdef ENABLE_TRAINING_TORCH_INTEROP
3941
namespace contrib {
@@ -93,6 +95,8 @@ using NodeIndex = size_t;
9395
// using NodeAttributes = std::unordered_map<std::string, ONNX_NAMESPACE::AttributeProto_Copyable>;
9496
using ModelMetaData = std::unordered_map<std::string, std::string>;
9597

98+
using IOnnxRuntimeOpSchemaCollectionPtr = std::shared_ptr<IOnnxRuntimeOpSchemaCollection>;
99+
using IOnnxRuntimeOpSchemaRegistryList = std::list<IOnnxRuntimeOpSchemaCollectionPtr>;
96100
using InitializedTensorSet = std::unordered_map<std::string, const ONNX_NAMESPACE::TensorProto*>;
97101

98102
struct Node__NodeIterator {
@@ -435,6 +439,7 @@ struct ProviderHost {
435439
virtual void TensorProto__clear_int64_data(ONNX_NAMESPACE::TensorProto* p) = 0;
436440
virtual void TensorProto__clear_double_data(ONNX_NAMESPACE::TensorProto* p) = 0;
437441
virtual void TensorProto__clear_uint64_data(ONNX_NAMESPACE::TensorProto* p) = 0;
442+
virtual void TensorProto__set_data_location(ONNX_NAMESPACE::TensorProto* p, ONNX_NAMESPACE::TensorProto_DataLocation data_location) = 0;
438443

439444
virtual bool TensorProto_DataType_IsValid(int value) = 0;
440445

@@ -755,8 +760,9 @@ struct ProviderHost {
755760
virtual void NodeAttributes__reserve(NodeAttributes* p, size_t size) = 0;
756761

757762
// Model
758-
virtual std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto,
759-
const PathString& model_path, const logging::Logger& logger) = 0;
763+
virtual std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
764+
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
765+
const logging::Logger& logger) = 0;
760766
virtual void Model__operator_delete(Model* p) = 0;
761767
virtual Graph& Model__MainGraph(Model* p) = 0;
762768
virtual std::unique_ptr<ONNX_NAMESPACE::ModelProto> Model__ToProto(Model* p) = 0;
@@ -814,6 +820,7 @@ struct ProviderHost {
814820
virtual Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept = 0;
815821
virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0;
816822
virtual const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const = 0;
823+
virtual IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const = 0;
817824

818825
// GraphViewer
819826
virtual void GraphViewer__operator_delete(GraphViewer* p) = 0;

onnxruntime/core/providers/shared_library/provider_wrappedtypes.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ struct TensorProto final {
205205

206206
bool has_data_location() const { return g_host->TensorProto__has_data_location(this); }
207207
TensorProto_DataLocation data_location() const { return TensorProto_DataLocation(g_host->TensorProto__data_location(this)); }
208+
void set_data_location(TensorProto_DataLocation data_location) { return g_host->TensorProto__set_data_location(this, data_location); }
208209

209210
bool has_raw_data() const { return g_host->TensorProto__has_raw_data(this); }
210211
const std::string& raw_data() const { return g_host->TensorProto__raw_data(this); }
@@ -778,8 +779,8 @@ struct NodeAttributes final {
778779

779780
struct Model final {
780781
static std::unique_ptr<Model> Create(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
781-
const logging::Logger& logger) {
782-
return g_host->Model__construct(std::move(model_proto), model_path, logger);
782+
const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) {
783+
return g_host->Model__construct(std::move(model_proto), model_path, local_registries, logger);
783784
}
784785
static void operator delete(void* p) { g_host->Model__operator_delete(reinterpret_cast<Model*>(p)); }
785786
static Status Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { return g_host->Model__Load(file_path, model_proto); }
@@ -857,6 +858,7 @@ struct Graph final {
857858
const Node* GetNode(NodeIndex node_index) const noexcept { return g_host->Graph__GetNode(this, node_index); }
858859
Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); }
859860
const NodeArg* GetNodeArg(const std::string& name) const { return g_host->Graph__GetNodeArg(this, name); }
861+
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->Graph__GetSchemaRegistry(this); }
860862

861863
PROVIDER_DISALLOW_ALL(Graph)
862864
};

onnxruntime/core/providers/vitisai/imp/global_api.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
188188
auto file_path = ToPathString(filename);
189189
auto status = Model::Load(file_path, *model_proto);
190190
vai_assert(status.IsOK(), "load model proto error");
191-
auto model = Model::Create(std::move(*model_proto), file_path, logger);
191+
auto model = Model::Create(std::move(*model_proto), file_path, nullptr, logger);
192192
return model.release();
193193
};
194194
the_global_api.model_delete = [](Model* model) { delete model; };
@@ -198,7 +198,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
198198
auto& model = const_cast<onnxruntime::Model&>(const_model);
199199
auto model_proto = model.ToProto();
200200
auto file_path = model.MainGraph().ModelPath().ToPathString();
201-
auto ret = Model::Create(std::move(*model_proto), file_path, logger);
201+
auto local_registries = IOnnxRuntimeOpSchemaRegistryList{model.MainGraph().GetSchemaRegistry()};
202+
auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger);
202203
auto status = ret->MainGraph().Resolve();
203204
vai_assert(status.IsOK(), status.ErrorMessage());
204205
return ret.release();

onnxruntime/core/providers/vitisai/imp/node.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,17 @@ vaip_core::DllSafe<std::vector<const NodeArg*>> node_get_output_node_args(const
3434
auto ret = std::vector<const NodeArg*>(size);
3535
for (auto i = 0u; i < size; ++i) {
3636
auto output = outputs[i];
37-
ret[i] = output;
3837
assert(output != nullptr);
39-
vai_assert(output->Exists(), std::string("output must exists. name=" + output->Name()));
38+
// Optional Outputs
39+
// Some operators have outputs that are optional. When an actual output parameter of an operator is not specified, the operator implementation MAY forgo computing values for such outputs.
40+
// There are two ways to leave an optional input or output unspecified: the first, available only for trailing inputs and outputs, is to simply not provide that input; the second method is to use an empty string in place of an input or output name.
41+
// so optional output maybe output != null && output->Exists() return false
42+
// Our processing : nullptr means optional output , and clinet code needs to handle nullptr
43+
if (output->Exists()) {
44+
ret[i] = output;
45+
} else {
46+
ret[i] = nullptr;
47+
}
4048
}
4149
return vaip_core::DllSafe(ret);
4250
}

onnxruntime/core/providers/vitisai/imp/tensor_proto.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ gsl::span<const char> tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& ten
2222
mut_tensor.clear_double_data();
2323
mut_tensor.clear_uint64_data();
2424
memcpy(mut_tensor.mutable_raw_data()->data(), unpacked_tensor.data(), unpacked_tensor.size());
25+
mut_tensor.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_DEFAULT);
2526
}
2627
return gsl::span<const char>(tensor.raw_data().data(), tensor.raw_data().size());
2728
}

onnxruntime/core/session/provider_bridge_ort.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,7 @@ struct ProviderHostImpl : ProviderHost {
528528
void TensorProto__add_dims(ONNX_NAMESPACE::TensorProto* p, int64_t value) override { p->add_dims(value); }
529529
bool TensorProto__has_data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_data_location(); }
530530
int TensorProto__data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_location(); }
531+
void TensorProto__set_data_location(ONNX_NAMESPACE::TensorProto* p, ONNX_NAMESPACE::TensorProto_DataLocation data_location) override { return p->set_data_location(data_location); }
531532
bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_raw_data(); }
532533
const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->raw_data(); }
533534
std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_raw_data(); }
@@ -966,8 +967,9 @@ struct ProviderHostImpl : ProviderHost {
966967

967968
// Model (wrapped)
968969
std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
970+
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
969971
const logging::Logger& logger) override {
970-
return std::make_unique<Model>(model_proto, model_path, nullptr, logger);
972+
return std::make_unique<Model>(model_proto, model_path, local_registries, logger);
971973
}
972974
void Model__operator_delete(Model* p) override { delete p; }
973975
Graph& Model__MainGraph(Model* p) override { return p->MainGraph(); }
@@ -1047,6 +1049,7 @@ struct ProviderHostImpl : ProviderHost {
10471049
Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept override { return p->GetNode(node_index); }
10481050
const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const override { return p->GetNode(node_index); }
10491051
const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const override { return p->GetNodeArg(name); }
1052+
IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const override { return p->GetSchemaRegistry(); }
10501053

10511054
// GraphViewer (wrapped)
10521055
void GraphViewer__operator_delete(GraphViewer* p) override { delete p; }

onnxruntime/test/perftest/ort_test_session.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
649649
std::string value(token.substr(pos + 1));
650650
vitisai_session_options[key] = value;
651651
}
652-
session_options.AppendExecutionProvider("VitisAI", vitisai_session_options);
652+
session_options.AppendExecutionProvider_VitisAI(vitisai_session_options);
653653
#else
654654
ORT_THROW("VitisAI is not supported in this build\n");
655655
#endif

0 commit comments

Comments
 (0)