@@ -291,6 +291,13 @@ static LogicalResult tryRemoveQThenDQChain(
291291 return success ();
292292}
293293
294+ static bool isValuePreservingOp (mlir::Operation *op) {
295+ if (!op)
296+ return false ;
297+ return isa<mlir::ONNXIdentityOp, mlir::ONNXReshapeOp, mlir::ONNXSqueezeOp,
298+ mlir::ONNXUnsqueezeOp, mlir::ONNXTransposeOp>(op);
299+ }
300+
294301template <typename BinOp>
295302struct FoldBinaryThroughQDQ : public OpRewritePattern <BinOp> {
296303 using OpRewritePattern<BinOp>::OpRewritePattern;
@@ -326,8 +333,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern<BinOp> {
326333 // Case 2: The input to the DQ op comes from a chain whose input is a
327334 // constant.
328335 else if (auto intermediateOp = dq1.getX ().getDefiningOp ()) {
329- if (isa<mlir::ONNXIdentityOp, mlir::ONNXReshapeOp, mlir::ONNXSqueezeOp,
330- mlir::ONNXUnsqueezeOp>(intermediateOp)) {
336+ if (isValuePreservingOp (intermediateOp)) {
331337 if (auto constOp =
332338 intermediateOp->getOperand (0 ).getDefiningOp <ONNXConstantOp>()) {
333339 constantDqOp = dq1;
@@ -336,8 +342,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern<BinOp> {
336342 }
337343 }
338344 } else if (auto intermediateOp = dq2.getX ().getDefiningOp ()) {
339- if (isa<mlir::ONNXIdentityOp, mlir::ONNXReshapeOp, mlir::ONNXSqueezeOp,
340- mlir::ONNXUnsqueezeOp>(intermediateOp)) {
345+ if (isValuePreservingOp (intermediateOp)) {
341346 if (auto constOp =
342347 intermediateOp->getOperand (0 ).getDefiningOp <ONNXConstantOp>()) {
343348 constantDqOp = dq2;
0 commit comments