Skip to content

Commit 99c057f

Browse files
authored
Merge pull request #468 from Xilinx/matthias.attributes_on_constant
ONNXConstantOp: Allow extra attributes (parser/printer)
2 parents de1d2df + 5e2ee82 commit 99c057f

File tree

3 files changed

+52
-9
lines changed

3 files changed

+52
-9
lines changed

src/Dialect/ONNX/ONNXOps.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,21 @@ void printNamedAttribute(OpAsmPrinter &printer, NamedAttribute namedAttr) {
104104
printAttribute(printer, namedAttr.getValue());
105105
}
106106

107-
void printOptionalAttrDict(
108-
OpAsmPrinter &printer, ArrayRef<NamedAttribute> attrs) {
107+
void printOptionalAttrDict(OpAsmPrinter &printer,
108+
ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {}) {
109+
SmallVector<NamedAttribute, 4> filteredAttrs;
110+
for (NamedAttribute attr : attrs) {
111+
if (llvm::is_contained(elidedAttrs, attr.getName().strref()))
112+
continue;
113+
filteredAttrs.emplace_back(attr);
114+
}
109115
// If there are no attributes, then there is nothing to be done.
110-
if (attrs.empty())
116+
if (filteredAttrs.empty())
111117
return;
112118

113119
// Otherwise, print them all out in braces.
114120
printer << " {";
115-
llvm::interleaveComma(attrs, printer.getStream(),
121+
llvm::interleaveComma(filteredAttrs, printer.getStream(),
116122
[&](NamedAttribute attr) { printNamedAttribute(printer, attr); });
117123
printer << '}';
118124
}
@@ -156,6 +162,7 @@ void ONNXConstantOp::print(OpAsmPrinter &printer) {
156162
assert(!mlir::isa<SparseElementsAttr>(elements) &&
157163
"ONNXConstantOp value cannot be sparse");
158164
if (elements.getType() == resultType) {
165+
printOptionalAttrDict(printer, (*this)->getAttrs(), {"value"});
159166
printer << ' ';
160167
printAttribute(printer, elements);
161168
return;
@@ -165,6 +172,7 @@ void ONNXConstantOp::print(OpAsmPrinter &printer) {
165172
// ONNXConstantOp sparse_value must be SparseElementsAttr.
166173
auto sparseElements = mlir::cast<SparseElementsAttr>(*attr);
167174
if (sparseElements.getType() == resultType) {
175+
printOptionalAttrDict(printer, (*this)->getAttrs(), {"sparse_value"});
168176
printer << ' ';
169177
printer.printAttribute(sparseElements);
170178
return;
@@ -182,9 +190,10 @@ ParseResult ONNXConstantOp::parse(OpAsmParser &parser, OperationState &result) {
182190
// First try to parse attribute dictionary.
183191
if (parser.parseOptionalAttrDict(result.attributes))
184192
return failure();
185-
// If there is no attribute dictionary, the parse above succeeds parsing
186-
// nothing. We detect this case by the absence of result attributes.
187-
if (result.attributes.empty()) {
193+
// If value/sparse_value were not in the attributes, parse them from their
194+
// pretty form.
195+
if (!result.attributes.get("value") &&
196+
!result.attributes.get("sparse_value")) {
188197
// Try to parse a SparseElementsAttr or or other ElementsAttr.
189198
OptionalParseResult opt = parser.parseOptionalAttribute(attr, type);
190199
if (opt.has_value()) {

test/mlir/onnx/onnx_constant.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: onnx-mlir-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
2+
// RUN: onnx-mlir-opt %s | FileCheck %s
3+
4+
func.func @test() {
5+
%generic = "onnx.Constant"() {value = dense<-1> : tensor<1xi64> } : () -> tensor<1xi64>
6+
%pretty = onnx.Constant dense<-1> : tensor<1xi64>
7+
8+
%generic_with_extra_attr = "onnx.Constant"() {value = dense<-1> : tensor<1xi64>, extra_attr = 0 : i32 } : () -> tensor<1xi64>
9+
%pretty_with_extra_attr = onnx.Constant {extra_attr = 0 : i32} dense<-1> : tensor<1xi64>
10+
11+
%generic_dynamic = "onnx.Constant"() {value = dense<-1> : tensor<1xi64> } : () -> tensor<*xi64>
12+
%pretty_dynamic = onnx.Constant {value = dense<-1> : tensor<1xi64>} : tensor<*xi64>
13+
14+
%generic_dynamic_with_extra_attr = "onnx.Constant"() {value = dense<-1> : tensor<1xi64>, extra_attr = 0 : i32 } : () -> tensor<*xi64>
15+
%pretty_dynamic_with_extra_attr = onnx.Constant {extra_attr = 0 : i32, value = dense<-1> : tensor<1xi64>} : tensor<*xi64>
16+
func.return
17+
}
18+
19+
// GENERIC: "onnx.Constant"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
20+
// GENERIC: "onnx.Constant"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
21+
// GENERIC: "onnx.Constant"() {extra_attr = 0 : i32, value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
22+
// GENERIC: "onnx.Constant"() {extra_attr = 0 : i32, value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
23+
// GENERIC: "onnx.Constant"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<*xi64>
24+
// GENERIC: "onnx.Constant"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<*xi64>
25+
// GENERIC: "onnx.Constant"() {extra_attr = 0 : i32, value = dense<-1> : tensor<1xi64>} : () -> tensor<*xi64>
26+
// GENERIC: "onnx.Constant"() {extra_attr = 0 : i32, value = dense<-1> : tensor<1xi64>} : () -> tensor<*xi64>
27+
// CHECK: onnx.Constant dense<-1> : tensor<1xi64>
28+
// CHECK: onnx.Constant dense<-1> : tensor<1xi64>
29+
// CHECK: onnx.Constant {extra_attr = 0 : i32} dense<-1> : tensor<1xi64>
30+
// CHECK: onnx.Constant {extra_attr = 0 : i32} dense<-1> : tensor<1xi64>
31+
// CHECK: onnx.Constant {value = dense<-1> : tensor<1xi64>} : tensor<*xi64>
32+
// CHECK: onnx.Constant {value = dense<-1> : tensor<1xi64>} : tensor<*xi64>
33+
// CHECK: onnx.Constant {extra_attr = 0 : i32, value = dense<-1> : tensor<1xi64>} : tensor<*xi64>
34+
// CHECK: onnx.Constant {extra_attr = 0 : i32, value = dense<-1> : tensor<1xi64>} : tensor<*xi64>

test/mlir/onnx/parse/com.microsoft.qdq_linear.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --printIR %s | FileCheck %s
22
// Semi hand-written model.
33
// When converted to onnxtext, onnx-mlir didn't like the result.
4-
// CHECK-DAG: [[SCALE:%.+]] = onnx.Constant dense<-1.08420217E-19> : tensor<f32>
5-
// CHECK-DAG: [[ZERO_P:%.+]] = onnx.Constant dense<0> : tensor<i8>
4+
// CHECK-DAG: [[SCALE:%.+]] = onnx.Constant {onnx_node_name = "scale"} dense<-1.08420217E-19> : tensor<f32>
5+
// CHECK-DAG: [[ZERO_P:%.+]] = onnx.Constant {onnx_node_name = "zeropoint"} dense<0> : tensor<i8>
66
// CHECK: [[DQ:%.+]] = "onnx.DequantizeLinear"(%arg0, [[SCALE]], [[ZERO_P]]) {axis = 1 : si64, block_size = 0 : si64, onnx_node_name = "myDequantizeLinear"} : (tensor<1x64x112x112xi8>, tensor<f32>, tensor<i8>) -> tensor<1x64x112x112xf32>
77
// CHECK: [[RELU:%.+]] = "onnx.Relu"([[DQ]]) {onnx_node_name = "myrelu1Relu"} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>
88
// CHECK: [[Q:%.+]] = "onnx.QuantizeLinear"([[RELU]], [[SCALE]], [[ZERO_P]]) {axis = 1 : si64, block_size = 0 : si64, onnx_node_name = "myQuantizeLinear_1", output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x64x112x112xf32>, tensor<f32>, tensor<i8>) -> tensor<1x64x112x112xi8>

0 commit comments

Comments
 (0)