Skip to content

Commit d6a01d2

Browse files
authored
Merge pull request #477 from Xilinx/jrickert.ai_onnx_domain
Handle aliasing of "" domain and "ai.onnx" domain.
2 parents 17c0b81 + 0e5e9bd commit d6a01d2

File tree

2 files changed

+46
-12
lines changed

2 files changed

+46
-12
lines changed

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ bool isDefaultDomain(std::string_view domain) {
6868
return domain.empty() || (domain == "ai.onnx");
6969
}
7070

71+
std::string canonicalizeDomain(std::string_view domain) {
72+
// Handle aliasing of "ai.onnx" and "". According to the onnx documentation,
73+
// the default domain is "ai.onnx", but in practice it seems like the
74+
// empty-string domain is used by default.
75+
return isDefaultDomain(domain) ? "" : std::string(domain);
76+
}
77+
7178
/// We consider opset < 6 is old. Users will see a warning if their model
7279
/// contains ops of old opset.
7380
constexpr int32_t MINIMUM_SUPPORTED_OPSET = 6;
@@ -86,7 +93,13 @@ template <class T>
8693
OpsetImportsMap GetOpsetImportsFromProto(const T &proto) {
8794
OpsetImportsMap opset_imports;
8895
for (const auto &opset_import : proto.opset_import()) {
89-
opset_imports[opset_import.domain()] = opset_import.version();
96+
const auto domain = canonicalizeDomain(opset_import.domain());
97+
const auto [iter_, inserted] =
98+
opset_imports.emplace(domain, opset_import.version());
99+
if (!inserted) {
100+
llvm::errs() << "Warning: Domain " << domain
101+
<< " found multiple times in opset imports.\n";
102+
}
90103
}
91104
return opset_imports;
92105
}
@@ -101,7 +114,8 @@ ModelLocalFunctionsMap GetModelLocalFunctions(const onnx::ModelProto &m) {
101114
for (const auto &function_proto : m.functions()) {
102115
model_local_functions_by_id.insert(
103116
{GetModelLocalFunctionsMapIdentifier(
104-
function_proto.domain(), function_proto.name()),
117+
canonicalizeDomain(function_proto.domain()),
118+
function_proto.name()),
105119
&function_proto});
106120
}
107121
return model_local_functions_by_id;
@@ -1298,18 +1312,19 @@ class FrontendGenImpl {
12981312
return onnx::OpSchemaRegistry::Schema(node.op_type(), version, domain);
12991313
}
13001314

1301-
std::string GetImportVersionOfNode(const onnx::NodeProto &node) {
1302-
auto current_opset_it = opset_map_.find(node.domain());
1315+
std::string GetImportVersionOfNode(
1316+
const onnx::NodeProto &node, const std::string &domain) {
1317+
auto current_opset_it = opset_map_.find(domain);
13031318
if (current_opset_it == opset_map_.end())
13041319
return "";
13051320

13061321
const int current_opset = current_opset_it->second;
13071322

1308-
const auto op_version_map = dialect_op_version_map_.find(node.domain());
1323+
const auto op_version_map = dialect_op_version_map_.find(domain);
13091324
if (op_version_map == dialect_op_version_map_.end())
13101325
return "";
13111326

1312-
const auto op_opsets_map = dialect_op_opsets_map_.find(node.domain());
1327+
const auto op_opsets_map = dialect_op_opsets_map_.find(domain);
13131328
if (op_opsets_map == dialect_op_opsets_map_.end())
13141329
return "";
13151330

@@ -1363,7 +1378,7 @@ class FrontendGenImpl {
13631378

13641379
// A new opset is added to onnx-mlir when it becomes incompatible.
13651380
// All opset newest than the last opset should use the last opset(version)
1366-
if (isDefaultDomain(node.domain()) &&
1381+
if (isDefaultDomain(domain) &&
13671382
upperRangeOfNewestValidOpsetVersion < supported_opset_list.back() &&
13681383
upperRangeOfNewestValidOpsetVersion < MINIMUM_SUPPORTED_OPSET)
13691384
llvm::errs() << "\nWarning: ONNX " << node.op_type()
@@ -1592,8 +1607,8 @@ class FrontendGenImpl {
15921607
auto mlirAttr = builder_.getStringAttr(funcName);
15931608
auto funcAttr = builder_.getNamedAttr("function_name", mlirAttr);
15941609
attributes.push_back(funcAttr);
1595-
auto domainAttr = builder_.getNamedAttr(
1596-
"domain_name", builder_.getStringAttr(node.domain()));
1610+
auto domainAttr = builder_.getNamedAttr("domain_name",
1611+
builder_.getStringAttr(canonicalizeDomain(node.domain())));
15971612
attributes.push_back(domainAttr);
15981613
const int nIn = ONNXCustomOp::getNumberOfOperands();
15991614
const int nOut = ONNXCustomOp::getNumberOfResults();
@@ -1635,9 +1650,11 @@ class FrontendGenImpl {
16351650

16361651
[[nodiscard]] std::error_code ImportNode(
16371652
const onnx::NodeProto &node, std::string &errorMessage) {
1653+
const std::string domain = canonicalizeDomain(node.domain());
16381654

1639-
std::string opName = node.op_type() + GetImportVersionOfNode(node);
1640-
auto domainIt = import_handler_map_.find(node.domain());
1655+
const std::string opName =
1656+
node.op_type() + GetImportVersionOfNode(node, domain);
1657+
auto domainIt = import_handler_map_.find(domain);
16411658
if (domainIt != import_handler_map_.end()) {
16421659
auto handler = domainIt->second.find(opName);
16431660
std::vector<std::string> funcs = options_.functionsToDecompose;
@@ -1658,7 +1675,7 @@ class FrontendGenImpl {
16581675
}
16591676

16601677
auto model_function = in_model_functions_.find(
1661-
GetModelLocalFunctionsMapIdentifier(node.domain(), node.op_type()));
1678+
GetModelLocalFunctionsMapIdentifier(domain, node.op_type()));
16621679
if (model_function != in_model_functions_.end()) {
16631680
return ImportFunctionCallNode(
16641681
node, /*schema=*/nullptr, model_function->second, errorMessage);
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s
2+
3+
// Based on test_abs.onnxtext
4+
5+
<
6+
ir_version: 7,
7+
opset_import: ["" : 13],
8+
producer_name: "backend-test"
9+
>
10+
test_abs (float[3,4,5] x) => (float[3,4,5] y) {
11+
y = ai.onnx.Abs (x)
12+
}
13+
// CHECK-LABEL: func.func @main_graph
14+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x5xf32> {onnx.name = "x"}) -> (tensor<3x4x5xf32> {onnx.name = "y"}) {
15+
// CHECK: [[VAR_0_:%.+]] = "onnx.Abs"([[PARAM_0_]]) : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
16+
// CHECK: onnx.Return [[VAR_0_]] : tensor<3x4x5xf32>
17+
// CHECK: }

0 commit comments

Comments
 (0)