Skip to content

Commit b42f67f

Browse files
author
sushmita
committed
added utility isValuePreserving
1 parent 5a82ea7 commit b42f67f

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,13 @@ static LogicalResult tryRemoveQThenDQChain(
291291
return success();
292292
}
293293

294+
static bool isValuePreservingOp(mlir::Operation *op) {
295+
if (!op)
296+
return false;
297+
return isa<mlir::ONNXIdentityOp, mlir::ONNXReshapeOp, mlir::ONNXSqueezeOp,
298+
mlir::ONNXUnsqueezeOp, mlir::ONNXTransposeOp>(op);
299+
}
300+
294301
template <typename BinOp>
295302
struct FoldBinaryThroughQDQ : public OpRewritePattern<BinOp> {
296303
using OpRewritePattern<BinOp>::OpRewritePattern;
@@ -326,8 +333,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern<BinOp> {
326333
// Case 2: The input to the DQ op comes from a chain whose input is a
327334
// constant.
328335
else if (auto intermediateOp = dq1.getX().getDefiningOp()) {
329-
if (isa<mlir::ONNXIdentityOp, mlir::ONNXReshapeOp, mlir::ONNXSqueezeOp,
330-
mlir::ONNXUnsqueezeOp>(intermediateOp)) {
336+
if (isValuePreservingOp(intermediateOp)) {
331337
if (auto constOp =
332338
intermediateOp->getOperand(0).getDefiningOp<ONNXConstantOp>()) {
333339
constantDqOp = dq1;
@@ -336,8 +342,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern<BinOp> {
336342
}
337343
}
338344
} else if (auto intermediateOp = dq2.getX().getDefiningOp()) {
339-
if (isa<mlir::ONNXIdentityOp, mlir::ONNXReshapeOp, mlir::ONNXSqueezeOp,
340-
mlir::ONNXUnsqueezeOp>(intermediateOp)) {
345+
if (isValuePreservingOp(intermediateOp)) {
341346
if (auto constOp =
342347
intermediateOp->getOperand(0).getDefiningOp<ONNXConstantOp>()) {
343348
constantDqOp = dq2;

0 commit comments

Comments
 (0)