Skip to content

Commit 5a200dd

Browse files
committed
Check that scale and zp are all ScalarConstantTensor; only run the DQ/Q-around-op removal if quantization is per-tensor
1 parent 97d6ca4 commit 5a200dd

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

src/Dialect/ONNX/ONNXOps/OpHelper.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -315,16 +315,26 @@ ElementsAttr getElementAttributeFromONNXValue(Value value) {
315315
return mlir::dyn_cast<ElementsAttr>(constantOp.getValueAttr());
316316
return nullptr;
317317
}
318+
319+
// compare two ElementsAttr, except for their internal buffer size
318320
bool compareValueFromElementAttribute(
319321
ElementsAttr &attr1, ElementsAttr &attr2) {
320-
auto values1 = attr1.getValues<mlir::Attribute>();
321-
auto values2 = attr2.getValues<mlir::Attribute>();
322-
323-
if (values1.size() != values2.size()) {
322+
if (attr1.getType() != attr2.getType()) {
324323
return false;
325324
}
326-
return std::equal(values1.begin(), values1.end(), values2.begin());
325+
if (attr1.getNumElements() != attr2.getNumElements()) {
326+
return false;
327+
}
328+
auto it1 = attr1.getValues<mlir::Attribute>().begin();
329+
auto it2 = attr2.getValues<mlir::Attribute>().begin();
330+
for (; it1 != attr1.getValues<mlir::Attribute>().end(); ++it1, ++it2) {
331+
if (*it1 != *it2) {
332+
return false;
333+
}
334+
}
335+
return true;
327336
}
337+
328338
// Returns the ConstantOp which defines an MLIR Value or null.
329339
ONNXConstantOp getONNXConstantOp(Value value) {
330340
return mlir::dyn_cast_or_null<ONNXConstantOp>(value.getDefiningOp());

src/Dialect/ONNX/Transforms/QDQAroundOpOpt.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,18 @@ class RemoveQDQAroundOpPattern : public OpRewritePattern<T> {
5454

5555
LogicalResult matchAndRewrite(
5656
T op, PatternRewriter &rewriter) const override {
57-
// if (llvm::isa<ONNXResizeOp>(op)) {
58-
if (auto resizeOp = dyn_cast<ONNXResizeOp>(op)) {
59-
// auto &resizeOp = llvm::cast<ONNXResizeOp>(op);
57+
if (llvm::isa<ONNXResizeOp>(op)) {
58+
auto &resizeOp = llvm::cast<ONNXResizeOp>(op);
6059
if (resizeOp.getMode() != "nearest") {
6160
return failure();
6261
}
6362
}
6463
InputAndOutput opIO = getDataInputOutput(op);
6564

6665
auto dqOp = opIO.input.getDefiningOp<ONNXDequantizeLinearOp>();
67-
if (!dqOp) {
66+
// Only run this pass if quantization is per-tensor (scalar scale and zero point)
67+
if (!dqOp || !isScalarConstantTensor(dqOp.getXScale()) ||
68+
!isScalarConstantTensor(dqOp.getXZeroPoint())) {
6869
return failure();
6970
}
7071
if (!opIO.output.hasOneUse()) {
@@ -73,6 +74,10 @@ class RemoveQDQAroundOpPattern : public OpRewritePattern<T> {
7374

7475
Operation *firstOp = *(opIO.output.getUsers().begin());
7576
if (auto qOp = dyn_cast<ONNXQuantizeLinearOp>(firstOp)) {
77+
if (!isScalarConstantTensor(qOp.getYScale()) ||
78+
!isScalarConstantTensor(qOp.getYZeroPoint())) {
79+
return failure();
80+
}
7681
if (!isDequantQuantSame(dqOp, qOp))
7782
return failure();
7883

0 commit comments

Comments
 (0)