File tree Expand file tree Collapse file tree 2 files changed +24
-9
lines changed Expand file tree Collapse file tree 2 files changed +24
-9
lines changed Original file line number Diff line number Diff line change @@ -315,16 +315,26 @@ ElementsAttr getElementAttributeFromONNXValue(Value value) {
315315 return mlir::dyn_cast<ElementsAttr>(constantOp.getValueAttr ());
316316 return nullptr ;
317317}
318+
319+ // compare two ElementsAttr, except for their internal buffer size
318320bool compareValueFromElementAttribute (
319321 ElementsAttr &attr1, ElementsAttr &attr2) {
320- auto values1 = attr1.getValues <mlir::Attribute>();
321- auto values2 = attr2.getValues <mlir::Attribute>();
322-
323- if (values1.size () != values2.size ()) {
322+ if (attr1.getType () != attr2.getType ()) {
324323 return false ;
325324 }
326- return std::equal (values1.begin (), values1.end (), values2.begin ());
325+ if (attr1.getNumElements () != attr2.getNumElements ()) {
326+ return false ;
327+ }
328+ auto it1 = attr1.getValues <mlir::Attribute>().begin ();
329+ auto it2 = attr2.getValues <mlir::Attribute>().begin ();
330+ for (; it1 != attr1.getValues <mlir::Attribute>().end (); ++it1, ++it2) {
331+ if (*it1 != *it2) {
332+ return false ;
333+ }
334+ }
335+ return true ;
327336}
337+
328338// Returns the ConstantOp which defines an MLIR Value or null.
329339ONNXConstantOp getONNXConstantOp (Value value) {
330340 return mlir::dyn_cast_or_null<ONNXConstantOp>(value.getDefiningOp ());
Original file line number Diff line number Diff line change @@ -54,17 +54,18 @@ class RemoveQDQAroundOpPattern : public OpRewritePattern<T> {
5454
5555 LogicalResult matchAndRewrite (
5656 T op, PatternRewriter &rewriter) const override {
57- // if (llvm::isa<ONNXResizeOp>(op)) {
58- if (auto resizeOp = dyn_cast<ONNXResizeOp>(op)) {
59- // auto &resizeOp = llvm::cast<ONNXResizeOp>(op);
57+ if (llvm::isa<ONNXResizeOp>(op)) {
58+ auto &resizeOp = llvm::cast<ONNXResizeOp>(op);
6059 if (resizeOp.getMode () != " nearest" ) {
6160 return failure ();
6261 }
6362 }
6463 InputAndOutput opIO = getDataInputOutput (op);
6564
6665 auto dqOp = opIO.input .getDefiningOp <ONNXDequantizeLinearOp>();
67- if (!dqOp) {
66+ // Only run this pass if Quantizelization is on tensor
67+ if (!dqOp || !isScalarConstantTensor (dqOp.getXScale ()) ||
68+ !isScalarConstantTensor (dqOp.getXZeroPoint ())) {
6869 return failure ();
6970 }
7071 if (!opIO.output .hasOneUse ()) {
@@ -73,6 +74,10 @@ class RemoveQDQAroundOpPattern : public OpRewritePattern<T> {
7374
7475 Operation *firstOp = *(opIO.output .getUsers ().begin ());
7576 if (auto qOp = dyn_cast<ONNXQuantizeLinearOp>(firstOp)) {
77+ if (!isScalarConstantTensor (qOp.getYScale ()) ||
78+ !isScalarConstantTensor (qOp.getYZeroPoint ())) {
79+ return failure ();
80+ }
7681 if (!isDequantQuantSame (dqOp, qOp))
7782 return failure ();
7883
You can’t perform that action at this time.
0 commit comments