1+ // ===- QDQAroundOpOpt.cpp - Remove DQ, Q operations around data movement ops
2+ // --------*- C++ -*-===//
3+ //
4+ // (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
5+ //
6+ // ===----------------------------------------------------------------------===//
7+
8+ #include < cmath>
9+ #include < mlir/IR/IRMapping.h>
10+ #include < mlir/IR/Operation.h>
11+ #include < mlir/IR/PatternMatch.h>
12+ #include < mlir/Pass/Pass.h>
13+ #include < mlir/Transforms/DialectConversion.h>
14+ #include < mlir/Transforms/GreedyPatternRewriteDriver.h>
15+ #include < src/Dialect/ONNX/ONNXOps.hpp>
16+ #include < src/Dialect/ONNX/ONNXOps/OpHelper.hpp>
17+
18+ using namespace mlir ;
19+ using namespace onnx_mlir ;
20+ struct InputAndOutput {
21+ Value input;
22+ Value output;
23+ };
24+
25+ InputAndOutput getDataInputOutput (ONNXTransposeOp transposeOp) {
26+ return {transposeOp.getData (), transposeOp.getTransposed ()};
27+ }
28+ InputAndOutput getDataInputOutput (ONNXUnsqueezeOp unsqueezeOp) {
29+ return {unsqueezeOp.getData (), unsqueezeOp.getExpanded ()};
30+ }
31+ InputAndOutput getDataInputOutput (ONNXSqueezeOp squeezeOp) {
32+ return {squeezeOp.getData (), squeezeOp.getSqueezed ()};
33+ }
34+ InputAndOutput getDataInputOutput (ONNXReshapeOp reshapeOp) {
35+ return {reshapeOp.getData (), reshapeOp.getReshaped ()};
36+ }
37+ InputAndOutput getDataInputOutput (ONNXGatherOp gatherOp) {
38+ return {gatherOp.getData (), gatherOp.getOutput ()};
39+ }
40+ InputAndOutput getDataInputOutput (ONNXSliceOp sliceOp) {
41+ return {sliceOp.getData (), sliceOp.getOutput ()};
42+ }
43+ InputAndOutput getDataInputOutput (ONNXResizeOp resizeOp) {
44+ return {resizeOp.getX (), resizeOp.getY ()};
45+ }
46+ InputAndOutput getDataInputOutput (ONNXFlattenOp flattenOp) {
47+ return {flattenOp.getInput (), flattenOp.getOutput ()};
48+ }
49+ namespace {
50+ template <typename T>
51+ class RemoveQDQAroundOpPattern : public OpRewritePattern <T> {
52+ public:
53+ using OpRewritePattern<T>::OpRewritePattern;
54+
55+ LogicalResult matchAndRewrite (
56+ T op, PatternRewriter &rewriter) const override {
57+ if (llvm::isa<ONNXResizeOp>(op)) {
58+ auto &resizeOp = llvm::cast<ONNXResizeOp>(op);
59+ if (resizeOp.getMode () != " nearest" ) {
60+ return failure ();
61+ }
62+ }
63+ InputAndOutput opIO = getDataInputOutput (op);
64+
65+ auto dqOp = opIO.input .getDefiningOp <ONNXDequantizeLinearOp>();
66+ // Only run this pass if Quantizelization is on tensor
67+ if (!dqOp || !isScalarConstantTensor (dqOp.getXScale ()) ||
68+ !isScalarConstantTensor (dqOp.getXZeroPoint ())) {
69+ return failure ();
70+ }
71+ if (!opIO.output .hasOneUse ()) {
72+ return failure ();
73+ }
74+
75+ Operation *firstOp = *(opIO.output .getUsers ().begin ());
76+ if (auto qOp = dyn_cast<ONNXQuantizeLinearOp>(firstOp)) {
77+ if (!isScalarConstantTensor (qOp.getYScale ()) ||
78+ !isScalarConstantTensor (qOp.getYZeroPoint ())) {
79+ return failure ();
80+ }
81+ if (!isDequantQuantSame (dqOp, qOp))
82+ return failure ();
83+
84+ // Map dqOp inputs to dqOp's inputs
85+ IRMapping irMapping;
86+ irMapping.map (dqOp, dqOp.getX ());
87+
88+ SmallVector<Value> newInputs;
89+ transform (op->getOperands (), std::back_inserter (newInputs),
90+ [&](Value operand) { return irMapping.lookupOrDefault (operand); });
91+
92+ auto newOp =
93+ rewriter.create <T>(op.getLoc (), TypeRange{qOp.getResult ().getType ()},
94+ ValueRange{newInputs}, op->getAttrs ());
95+ rewriter.replaceOp (qOp, newOp.getResult ());
96+ return success ();
97+ }
98+ };
99+ };
100+ struct QDQAroundOpOptONNXToONNXPass
101+ : public PassWrapper<QDQAroundOpOptONNXToONNXPass,
102+ OperationPass<func::FuncOp>> {
103+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (QDQAroundOpOptONNXToONNXPass)
104+ StringRef getArgument () const override {
105+ return " qdq-around-op-opt-onnx-to-onnx" ;
106+ }
107+ StringRef getDescription () const override {
108+ return " Remove QDQ around ops if safe." ;
109+ }
110+
111+ void runOnOperation () override {
112+ auto function = getOperation ();
113+ auto *ctx = &getContext ();
114+ RewritePatternSet patterns (ctx);
115+ // ONNXReduceSumOp is expecting high precision value, it failed to compile
116+ // during applying this pass, so for now there is no dq, q removal around
117+ // ReduceSum
118+ patterns.add <RemoveQDQAroundOpPattern<ONNXTransposeOp>,
119+ RemoveQDQAroundOpPattern<ONNXUnsqueezeOp>,
120+ RemoveQDQAroundOpPattern<ONNXSqueezeOp>,
121+ RemoveQDQAroundOpPattern<ONNXReshapeOp>,
122+ RemoveQDQAroundOpPattern<ONNXResizeOp>,
123+ RemoveQDQAroundOpPattern<ONNXGatherOp>,
124+ RemoveQDQAroundOpPattern<ONNXSliceOp>,
125+ RemoveQDQAroundOpPattern<ONNXFlattenOp>>(patterns.getContext ());
126+ if (failed (applyPatternsGreedily (function, std::move (patterns))))
127+ signalPassFailure ();
128+ }
129+ };
130+ } // namespace
131+
132+ namespace onnx_mlir {
133+ std::unique_ptr<mlir::Pass> createQDQAroundOpOptONNXToONNXPass () {
134+ return std::make_unique<QDQAroundOpOptONNXToONNXPass>();
135+ }
136+ } // namespace onnx_mlir
0 commit comments