@@ -35,70 +35,6 @@ static ElementsAttr getElementAttributeFromConstant(Value val) {
3535 return nullptr ;
3636}
3737
38- static mlir::LogicalResult equalsDefaultIntegerAttr (
39- mlir::IntegerAttr ia, int64_t defaultValue) {
40- auto it = mlir::cast<mlir::IntegerType>(ia.getType ());
41- int64_t got = it.isUnsignedInteger ()
42- ? static_cast <int64_t >(ia.getValue ().getZExtValue ())
43- : ia.getValue ().getSExtValue ();
44- return (got == defaultValue) ? mlir::success () : mlir::failure ();
45- }
46-
47- static mlir::LogicalResult equalsDefaultIntElements (
48- mlir::ElementsAttr ea, int64_t defaultValue) {
49- auto st = mlir::dyn_cast<mlir::ShapedType>(ea.getType ());
50- if (!st)
51- return mlir::failure ();
52- mlir::Type et = st.getElementType ();
53- if (!et.isIntOrIndex ())
54- return mlir::failure ();
55- const bool isUnsigned = et.isa <mlir::IntegerType>() &&
56- et.cast <mlir::IntegerType>().isUnsignedInteger ();
57- if (ea.isSplat ()) {
58- llvm::APInt api = ea.getSplatValue <llvm::APInt>();
59- int64_t got = isUnsigned ? static_cast <int64_t >(api.getZExtValue ())
60- : api.getSExtValue ();
61- return (got == defaultValue) ? mlir::success () : mlir::failure ();
62- }
63- for (const llvm::APInt &api : ea.getValues <llvm::APInt>()) {
64- int64_t got = isUnsigned ? static_cast <int64_t >(api.getZExtValue ())
65- : api.getSExtValue ();
66- if (got != defaultValue)
67- return mlir::failure ();
68- }
69- return mlir::success ();
70- }
71-
72- static mlir::LogicalResult checkAttrAgainstDefault (
73- mlir::Attribute attr, int64_t defaultValue) {
74- if (!attr)
75- return mlir::failure ();
76- if (auto ia = mlir::dyn_cast<mlir::IntegerAttr>(attr))
77- return equalsDefaultIntegerAttr (ia, defaultValue);
78- if (auto ea = mlir::dyn_cast<mlir::ElementsAttr>(attr))
79- return equalsDefaultIntElements (ea, defaultValue);
80- return mlir::failure ();
81- }
82-
83- static mlir::LogicalResult checkIntegerAttributeEquals (mlir::Operation *op1,
84- mlir::Operation *op2, mlir::StringRef attrName, int64_t defaultValue) {
85- mlir::Attribute attr1 = op1->getAttr (attrName);
86- mlir::Attribute attr2 = op2->getAttr (attrName);
87- // Case 0: both missing => both implicitly default
88- if (!attr1 && !attr2)
89- return mlir::success ();
90- // Case 1: both present and identical
91- if (attr1 && attr2 && attr1 == attr2)
92- return mlir::success ();
93- // Case 2: one side missing => present side must equal default
94- if (!attr1)
95- return checkAttrAgainstDefault (attr2, defaultValue);
96- if (!attr2)
97- return checkAttrAgainstDefault (attr1, defaultValue);
98- // Case 3: both present but not identical
99- return mlir::failure ();
100- }
101-
10238// ===----------------------------------------------------------------------===//
10339// Pattern to remove QDQ pairs
10440// ===----------------------------------------------------------------------===//
@@ -107,23 +43,16 @@ struct FoldQDQPattern : public OpRewritePattern<ONNXQuantizeLinearOp> {
10743 using OpRewritePattern<ONNXQuantizeLinearOp>::OpRewritePattern;
10844 LogicalResult matchAndRewrite (
10945 ONNXQuantizeLinearOp qOp, PatternRewriter &rewriter) const override {
46+
11047 auto dqOp = qOp.getX ().getDefiningOp <ONNXDequantizeLinearOp>();
11148 if (!dqOp)
11249 return failure ();
11350
114- // 1. Check attributes with defaults (axis=1, block_size=0,
115- // saturate=1)
116- Operation *dqOperation = dqOp.getOperation ();
117- Operation *qOperation = qOp.getOperation ();
118-
119- if (failed (
120- checkIntegerAttributeEquals (dqOperation, qOperation, " axis" , 1 )) ||
121- failed (checkIntegerAttributeEquals (
122- dqOperation, qOperation, " block_size" , 0 )) ||
123- failed (checkIntegerAttributeEquals (
124- dqOperation, qOperation, " saturate" , 1 ))) {
51+ // 1. Check Attributes
52+ if (qOp.getAxis () != dqOp.getAxis ())
53+ return failure ();
54+ if (qOp.getBlockSize () != dqOp.getBlockSize ())
12555 return failure ();
126- }
12756
12857 // 2. Check zero-points
12958 auto zpAttr1 = getElementAttributeFromConstant (dqOp.getXZeroPoint ());
@@ -172,9 +101,9 @@ struct QDQOptONNXToONNXPass
172101 : public PassWrapper<QDQOptONNXToONNXPass, OperationPass<func::FuncOp>> {
173102
174103 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (QDQOptONNXToONNXPass)
175- StringRef getArgument () const override { return " qdq -opt-onnx-to-onnx" ; }
104+ StringRef getArgument () const override { return " dqq -opt-onnx-to-onnx" ; }
176105 StringRef getDescription () const override {
177- return " Remove QDQ ops and surrounding QDQ if safe." ;
106+ return " Remove DqQ ops and surrounding DqQ if safe." ;
178107 }
179108
180109 void runOnOperation () override {
0 commit comments