Skip to content

Commit 0e5e9bd

Browse files
committed
Handle aliasing of "" domain and "ai.onnx" domain.
According to the onnx documentation, the default domain is "ai.onnx", but in practice it seems like the empty-string domain is used by default. Add a canonicalization function to handle them in the same way. Signed-off-by: Rickert, Jonas <[email protected]>
1 parent 17c0b81 commit 0e5e9bd

File tree

2 files changed

+46
-12
lines changed

2 files changed

+46
-12
lines changed

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ bool isDefaultDomain(std::string_view domain) {
6868
return domain.empty() || (domain == "ai.onnx");
6969
}
7070

71+
std::string canonicalizeDomain(std::string_view domain) {
72+
// Handle aliasing of "ai.onnx" and "". According to the onnx documentation,
73+
// the default domain is "ai.onnx", but in practice it seems like the
74+
// empty-string domain is used by default.
75+
return isDefaultDomain(domain) ? "" : std::string(domain);
76+
}
77+
7178
/// We consider opset < 6 is old. Users will see a warning if their model
7279
/// contains ops of old opset.
7380
constexpr int32_t MINIMUM_SUPPORTED_OPSET = 6;
@@ -86,7 +93,13 @@ template <class T>
8693
OpsetImportsMap GetOpsetImportsFromProto(const T &proto) {
8794
OpsetImportsMap opset_imports;
8895
for (const auto &opset_import : proto.opset_import()) {
89-
opset_imports[opset_import.domain()] = opset_import.version();
96+
const auto domain = canonicalizeDomain(opset_import.domain());
97+
const auto [iter_, inserted] =
98+
opset_imports.emplace(domain, opset_import.version());
99+
if (!inserted) {
100+
llvm::errs() << "Warning: Domain " << domain
101+
<< " found multiple times in opset imports.\n";
102+
}
90103
}
91104
return opset_imports;
92105
}
@@ -101,7 +114,8 @@ ModelLocalFunctionsMap GetModelLocalFunctions(const onnx::ModelProto &m) {
101114
for (const auto &function_proto : m.functions()) {
102115
model_local_functions_by_id.insert(
103116
{GetModelLocalFunctionsMapIdentifier(
104-
function_proto.domain(), function_proto.name()),
117+
canonicalizeDomain(function_proto.domain()),
118+
function_proto.name()),
105119
&function_proto});
106120
}
107121
return model_local_functions_by_id;
@@ -1298,18 +1312,19 @@ class FrontendGenImpl {
12981312
return onnx::OpSchemaRegistry::Schema(node.op_type(), version, domain);
12991313
}
13001314

1301-
std::string GetImportVersionOfNode(const onnx::NodeProto &node) {
1302-
auto current_opset_it = opset_map_.find(node.domain());
1315+
std::string GetImportVersionOfNode(
1316+
const onnx::NodeProto &node, const std::string &domain) {
1317+
auto current_opset_it = opset_map_.find(domain);
13031318
if (current_opset_it == opset_map_.end())
13041319
return "";
13051320

13061321
const int current_opset = current_opset_it->second;
13071322

1308-
const auto op_version_map = dialect_op_version_map_.find(node.domain());
1323+
const auto op_version_map = dialect_op_version_map_.find(domain);
13091324
if (op_version_map == dialect_op_version_map_.end())
13101325
return "";
13111326

1312-
const auto op_opsets_map = dialect_op_opsets_map_.find(node.domain());
1327+
const auto op_opsets_map = dialect_op_opsets_map_.find(domain);
13131328
if (op_opsets_map == dialect_op_opsets_map_.end())
13141329
return "";
13151330

@@ -1363,7 +1378,7 @@ class FrontendGenImpl {
13631378

13641379
// A new opset is added to onnx-mlir when it becomes incompatible.
13651380
// All opset newest than the last opset should use the last opset(version)
1366-
if (isDefaultDomain(node.domain()) &&
1381+
if (isDefaultDomain(domain) &&
13671382
upperRangeOfNewestValidOpsetVersion < supported_opset_list.back() &&
13681383
upperRangeOfNewestValidOpsetVersion < MINIMUM_SUPPORTED_OPSET)
13691384
llvm::errs() << "\nWarning: ONNX " << node.op_type()
@@ -1592,8 +1607,8 @@ class FrontendGenImpl {
15921607
auto mlirAttr = builder_.getStringAttr(funcName);
15931608
auto funcAttr = builder_.getNamedAttr("function_name", mlirAttr);
15941609
attributes.push_back(funcAttr);
1595-
auto domainAttr = builder_.getNamedAttr(
1596-
"domain_name", builder_.getStringAttr(node.domain()));
1610+
auto domainAttr = builder_.getNamedAttr("domain_name",
1611+
builder_.getStringAttr(canonicalizeDomain(node.domain())));
15971612
attributes.push_back(domainAttr);
15981613
const int nIn = ONNXCustomOp::getNumberOfOperands();
15991614
const int nOut = ONNXCustomOp::getNumberOfResults();
@@ -1635,9 +1650,11 @@ class FrontendGenImpl {
16351650

16361651
[[nodiscard]] std::error_code ImportNode(
16371652
const onnx::NodeProto &node, std::string &errorMessage) {
1653+
const std::string domain = canonicalizeDomain(node.domain());
16381654

1639-
std::string opName = node.op_type() + GetImportVersionOfNode(node);
1640-
auto domainIt = import_handler_map_.find(node.domain());
1655+
const std::string opName =
1656+
node.op_type() + GetImportVersionOfNode(node, domain);
1657+
auto domainIt = import_handler_map_.find(domain);
16411658
if (domainIt != import_handler_map_.end()) {
16421659
auto handler = domainIt->second.find(opName);
16431660
std::vector<std::string> funcs = options_.functionsToDecompose;
@@ -1658,7 +1675,7 @@ class FrontendGenImpl {
16581675
}
16591676

16601677
auto model_function = in_model_functions_.find(
1661-
GetModelLocalFunctionsMapIdentifier(node.domain(), node.op_type()));
1678+
GetModelLocalFunctionsMapIdentifier(domain, node.op_type()));
16621679
if (model_function != in_model_functions_.end()) {
16631680
return ImportFunctionCallNode(
16641681
node, /*schema=*/nullptr, model_function->second, errorMessage);
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s
2+
3+
// Based on test_abs.onnxtext
4+
5+
<
6+
ir_version: 7,
7+
opset_import: ["" : 13],
8+
producer_name: "backend-test"
9+
>
10+
test_abs (float[3,4,5] x) => (float[3,4,5] y) {
11+
y = ai.onnx.Abs (x)
12+
}
13+
// CHECK-LABEL: func.func @main_graph
14+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x5xf32> {onnx.name = "x"}) -> (tensor<3x4x5xf32> {onnx.name = "y"}) {
15+
// CHECK: [[VAR_0_:%.+]] = "onnx.Abs"([[PARAM_0_]]) : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
16+
// CHECK: onnx.Return [[VAR_0_]] : tensor<3x4x5xf32>
17+
// CHECK: }

0 commit comments

Comments
 (0)