5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
+ #include < algorithm>
9
+ #include < regex>
8
10
#include < string>
9
11
#include < vector>
10
- #include < regex>
11
- #include < algorithm>
12
12
13
13
#include " mlir/Dialect/Arith/IR/Arith.h"
14
14
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
18
18
#include " mlir/IR/OpImplementation.h"
19
19
#include " mlir/Interfaces/SideEffectInterfaces.h"
20
20
21
+ #include " polygeist/Ops.h"
22
+ #include " sql/Parser.h"
21
23
#include " sql/SQLDialect.h"
22
24
#include " sql/SQLOps.h"
23
25
#include " sql/SQLTypes.h"
24
- #include " sql/Parser.h"
25
- #include " polygeist/Ops.h"
26
26
27
27
#define GET_OP_CLASSES
28
28
#include " sql/SQLOps.cpp.inc"
41
41
#include " llvm/ADT/SetVector.h"
42
42
#include " llvm/Support/Debug.h"
43
43
44
-
45
- #include " mlir/IR/Value.h"
44
+ #include " mlir/IR/Attributes.h"
46
45
#include " mlir/IR/Builders.h"
46
+ #include " mlir/IR/BuiltinTypes.h"
47
47
#include " mlir/IR/Location.h"
48
- #include " mlir/IR/Attributes .h"
48
+ #include " mlir/IR/Value .h"
49
49
#include " llvm/ADT/SmallVector.h"
50
- #include " mlir/IR/BuiltinTypes.h"
51
50
52
51
#define DEBUG_TYPE " sql"
53
52
54
53
using namespace mlir ;
55
54
using namespace sql ;
56
55
using namespace mlir ::arith;
57
56
58
-
59
57
class GetValueOpTypeFix final : public OpRewritePattern<GetValueOp> {
60
58
public:
61
59
using OpRewritePattern<GetValueOp>::OpRewritePattern;
@@ -67,38 +65,38 @@ class GetValueOpTypeFix final : public OpRewritePattern<GetValueOp> {
67
65
68
66
Value handle = op.getOperand (0 );
69
67
if (!handle.getType ().isa <IndexType>()) {
70
- handle = rewriter.create <IndexCastOp>(op.getLoc (),
71
- rewriter.getIndexType (), handle);
72
- changed = true ;
68
+ handle = rewriter.create <IndexCastOp>(op.getLoc (),
69
+ rewriter.getIndexType (), handle);
70
+ changed = true ;
73
71
}
74
72
Value row = op.getOperand (1 );
75
73
if (!row.getType ().isa <IndexType>()) {
76
- row = rewriter.create <IndexCastOp>(op.getLoc (),
77
- rewriter. getIndexType (), row);
78
- changed = true ;
74
+ row = rewriter.create <IndexCastOp>(op.getLoc (), rewriter. getIndexType (),
75
+ row);
76
+ changed = true ;
79
77
}
80
78
Value column = op.getOperand (2 );
81
79
if (!column.getType ().isa <IndexType>()) {
82
- column = rewriter.create <IndexCastOp>(op.getLoc (),
83
- rewriter.getIndexType (), column);
84
- changed = true ;
80
+ column = rewriter.create <IndexCastOp>(op.getLoc (),
81
+ rewriter.getIndexType (), column);
82
+ changed = true ;
85
83
}
86
84
87
- if (!changed) return failure ();
85
+ if (!changed)
86
+ return failure ();
88
87
89
- rewriter.replaceOpWithNewOp <GetValueOp>(op, op.getType (), handle, row, column);
88
+ rewriter.replaceOpWithNewOp <GetValueOp>(op, op.getType (), handle, row,
89
+ column);
90
90
91
91
return success (changed);
92
92
}
93
93
};
94
94
95
95
void GetValueOp::getCanonicalizationPatterns (RewritePatternSet &results,
96
- MLIRContext *context) {
96
+ MLIRContext *context) {
97
97
results.insert <GetValueOpTypeFix>(context);
98
98
}
99
99
100
-
101
-
102
100
class NumResultsOpTypeFix final : public OpRewritePattern<NumResultsOp> {
103
101
public:
104
102
using OpRewritePattern<NumResultsOp>::OpRewritePattern;
@@ -108,34 +106,35 @@ class NumResultsOpTypeFix final : public OpRewritePattern<NumResultsOp> {
108
106
bool changed = false ;
109
107
Value handle = op->getOperand (0 );
110
108
111
- if (handle.getType ().isa <IndexType>() && op->getResultTypes ()[0 ].isa <IndexType>())
112
- return failure ();
109
+ if (handle.getType ().isa <IndexType>() &&
110
+ op->getResultTypes ()[0 ].isa <IndexType>())
111
+ return failure ();
113
112
114
113
if (!handle.getType ().isa <IndexType>()) {
115
- handle = rewriter.create <IndexCastOp>(op.getLoc (),
116
- rewriter.getIndexType (), handle);
117
- changed = true ;
114
+ handle = rewriter.create <IndexCastOp>(op.getLoc (),
115
+ rewriter.getIndexType (), handle);
116
+ changed = true ;
118
117
}
119
118
120
- mlir::Value res = rewriter.create <NumResultsOp>(op.getLoc (), rewriter.getIndexType (), handle);
119
+ mlir::Value res = rewriter.create <NumResultsOp>(
120
+ op.getLoc (), rewriter.getIndexType (), handle);
121
121
122
122
if (op->getResultTypes ()[0 ].isa <IndexType>()) {
123
- rewriter.replaceOp (op, res);
123
+ rewriter.replaceOp (op, res);
124
124
} else {
125
- rewriter.replaceOpWithNewOp <IndexCastOp>(op, op->getResultTypes ()[0 ], res);
125
+ rewriter.replaceOpWithNewOp <IndexCastOp>(op, op->getResultTypes ()[0 ],
126
+ res);
126
127
}
127
128
128
129
return success (changed);
129
130
}
130
131
};
131
132
132
133
void NumResultsOp::getCanonicalizationPatterns (RewritePatternSet &results,
133
- MLIRContext *context) {
134
+ MLIRContext *context) {
134
135
results.insert <NumResultsOpTypeFix>(context);
135
136
}
136
137
137
-
138
-
139
138
// class ExecuteOpTypeFix final : public OpRewritePattern<ExecuteOp> {
140
139
// public:
141
140
// using OpRewritePattern<ExecuteOp>::OpRewritePattern;
@@ -147,39 +146,44 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results,
147
146
// Value conn = op->getOperand(0);
148
147
// Value command = op->getOperand(1);
149
148
150
- // if (conn.getType().isa<IndexType>() && command.getType().isa<IndexType>() && op->getResultTypes()[0].isa<IndexType>())
149
+ // if (conn.getType().isa<IndexType>() && command.getType().isa<IndexType>()
150
+ // && op->getResultTypes()[0].isa<IndexType>())
151
151
// return failure();
152
152
153
153
// if (!conn.getType().isa<IndexType>()) {
154
154
// conn = rewriter.create<IndexCastOp>(op.getLoc(),
155
- // rewriter.getIndexType(), conn);
155
+ // rewriter.getIndexType(),
156
+ // conn);
156
157
// changed = true;
157
158
// }
158
159
// if (command.getType().isa<MemRefType>()) {
159
- // command = rewriter.create<polygeist::Memref2PointerOp>(op.getLoc(),
160
- // LLVM::LLVMPointerType::get(rewriter.getI8Type()), command);
160
+ // command = rewriter.create<polygeist::Memref2PointerOp>(op.getLoc(),
161
+ // LLVM::LLVMPointerType::get(rewriter.getI8Type()),
162
+ // command);
161
163
// changed = true;
162
164
// }
163
165
164
-
165
166
// if (command.getType().isa<LLVM::LLVMPointerType>()) {
166
- // command = rewriter.create<LLVM::PtrToIntOp>(op.getLoc(),
167
- // rewriter.getI64Type(), command);
167
+ // command = rewriter.create<LLVM::PtrToIntOp>(op.getLoc(),
168
+ // rewriter.getI64Type(),
169
+ // command);
168
170
// changed = true;
169
171
// }
170
172
// if (!command.getType().isa<IndexType>()) {
171
- // command = rewriter.create<IndexCastOp>(op.getLoc(),
172
- // rewriter.getIndexType(), command);
173
+ // command = rewriter.create<IndexCastOp>(op.getLoc(),
174
+ // rewriter.getIndexType(),
175
+ // command);
173
176
// changed = true;
174
177
// }
175
178
176
179
// if (!changed) return failure();
177
- // mlir::Value res = rewriter.create<ExecuteOp>(op.getLoc(), rewriter.getIndexType(), conn, command);
178
- // rewriter.replaceOp(op, res);
180
+ // mlir::Value res = rewriter.create<ExecuteOp>(op.getLoc(),
181
+ // rewriter.getIndexType(), conn, command); rewriter. replaceOp(op, res);
179
182
// // if (op->getResultTypes()[0].isa<IndexType>()) {
180
183
// // rewriter.replaceOp(op, res);
181
184
// // } else {
182
- // // rewriter.replaceOpWithNewOp<IndexCastOp>(op, op->getResultTypes()[0], res);
185
+ // // rewriter.replaceOpWithNewOp<IndexCastOp>(op,
186
+ // op->getResultTypes()[0], res);
183
187
// // }
184
188
// return success(changed);
185
189
// }
@@ -190,8 +194,7 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results,
190
194
// results.insert<ExecuteOpTypeFix>(context);
191
195
// }
192
196
193
-
194
- template <typename T>
197
+ template <typename T>
195
198
class UnparsedOpInnerCast final : public OpRewritePattern<UnparsedOp> {
196
199
public:
197
200
using OpRewritePattern<UnparsedOp>::OpRewritePattern;
@@ -200,39 +203,91 @@ class UnparsedOpInnerCast final : public OpRewritePattern<UnparsedOp> {
200
203
PatternRewriter &rewriter) const override {
201
204
202
205
Value input = op->getOperand (0 );
203
-
206
+
204
207
auto cst = input.getDefiningOp <T>();
205
- if (!cst) return failure ();
208
+ if (!cst)
209
+ return failure ();
206
210
207
211
rewriter.replaceOpWithNewOp <UnparsedOp>(op, op.getType (), cst.getOperand ());
208
212
return success ();
209
213
}
210
214
};
211
215
212
216
void UnparsedOp::getCanonicalizationPatterns (RewritePatternSet &results,
213
- MLIRContext *context) {
214
- results.insert <UnparsedOpInnerCast<polygeist::Pointer2MemrefOp> >(context);
217
+ MLIRContext *context) {
218
+ results.insert <UnparsedOpInnerCast<polygeist::Pointer2MemrefOp>>(context);
215
219
}
216
220
217
-
218
- class SQLStringConcatOpCanonicalization final : public OpRewritePattern<SQLStringConcatOp> {
221
+ class SQLStringConcatOpCanonicalization final
222
+ : public OpRewritePattern<SQLStringConcatOp> {
219
223
public:
220
224
using OpRewritePattern<SQLStringConcatOp>::OpRewritePattern;
221
225
222
226
LogicalResult matchAndRewrite (SQLStringConcatOp op,
223
227
PatternRewriter &rewriter) const override {
224
-
225
- auto input1 = op->getOperand (0 ).getDefiningOp <SQLConstantStringOp>();
226
- auto input2 = op->getOperand (1 ).getDefiningOp <SQLConstantStringOp>();
227
-
228
- if (!input1 || !input2) return failure ();
229
-
230
- rewriter.replaceOpWithNewOp <SQLConstantStringOp>(op, op.getType (), (input1.getInput () + input2.getInput ()).str ());
231
- return success ();
228
+ // Whether we changed the state. If we make no simplifications we need to
229
+ // return failure otherwise we will infinite loop
230
+ bool changed = false ;
231
+ // Operands to the simplified concat
232
+ SmallVector<Value> operands;
233
+ // Constants that we will merge, "current running constant"
234
+ SmallVector<SQLConstantStringOp> constants;
235
+ for (auto op : op->getOperands ()) {
236
+ if (auto constOp = op.getDefiningOp <SQLConstantStringOp>()) {
237
+ constants.push_back (constOp);
238
+ continue ;
239
+ }
240
+ if (constants.size () != 0 ) {
241
+ if (constants.size () == 1 ) {
242
+ operands.push_back (constants[0 ]);
243
+ } else {
244
+ std::string nextStr;
245
+ changed = true ;
246
+ for (auto str : constants)
247
+ nextStr += str.getInput ().str ();
248
+
249
+ operands.push_back (rewriter.create <SQLConstantStringOp>(
250
+ op.getLoc (), MemRefType::get ({-1 }, rewriter.getI8Type ()), nextStr));
251
+ }
252
+ }
253
+ constants.clear ();
254
+ if (auto concat = op.getDefiningOp <SQLStringConcatOp>()) {
255
+ changed = true ;
256
+ for (auto op2 : concat->getOperands ())
257
+ operands.push_back (op2);
258
+ continue ;
259
+ }
260
+ operands.push_back (op);
261
+ }
262
+ if (constants.size () != 0 ) {
263
+ if (constants.size () == 1 ) {
264
+ operands.push_back (constants[0 ]);
265
+ } else {
266
+ std::string nextStr;
267
+ changed = true ;
268
+ for (auto str : constants)
269
+ nextStr = nextStr + str.getInput ().str ();
270
+ operands.push_back (rewriter.create <SQLConstantStringOp>(
271
+ op.getLoc (), MemRefType::get ({-1 }, rewriter.getI8Type ()), nextStr));
272
+ }
273
+ }
274
+ if (operands.size () == 0 ) {
275
+ rewriter.replaceOpWithNewOp <SQLConstantStringOp>(op, MemRefType::get ({-1 }, rewriter.getI8Type ()), " " );
276
+ return success ();
277
+ }
278
+ if (operands.size () == 1 ) {
279
+ rewriter.replaceOp (op, operands[0 ]);
280
+ return success ();
281
+ }
282
+ if (changed) {
283
+ rewriter.replaceOpWithNewOp <SQLStringConcatOp>(op, MemRefType::get ({-1 }, rewriter.getI8Type ()), operands);
284
+ return success ();
285
+ }
286
+ return failure ();
232
287
}
233
288
};
234
289
235
290
void SQLStringConcatOp::getCanonicalizationPatterns (RewritePatternSet &results,
236
- MLIRContext *context) {
291
+ MLIRContext *context) {
237
292
results.insert <SQLStringConcatOpCanonicalization>(context);
238
- }
293
+ }
0 commit comments