Skip to content

Commit 0a26185

Browse files
committed
merge this before using new parser
1 parent 5249dbe commit 0a26185

File tree

4 files changed

+675
-460
lines changed

4 files changed

+675
-460
lines changed

include/sql/SQLOps.td

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include "SQLDialect.td"
1414
include "SQLTypes.td"
1515

1616

17+
1718
class SQL_Op<string mnemonic, list<Trait> traits = []>
1819
: Op<SQL_Dialect, mnemonic, traits>;
1920

@@ -61,7 +62,7 @@ def CalcBoolOp: SQL_Op<"calc_bool", [Pure]> {
6162
def AndOp: SQL_Op<"and", [Pure]> {
6263
let summary = "and op";
6364

64-
let arguments = (ins Variadic<SQLBoolType>:$expr);
65+
let arguments = (ins SQLBoolType:$left, SQLBoolType:$right);
6566
let results = (outs SQLBoolType:$result);
6667

6768
let hasFolder = 0;
@@ -71,7 +72,7 @@ def AndOp: SQL_Op<"and", [Pure]> {
7172
def OrOp: SQL_Op<"or", [Pure]> {
7273
let summary = "or op";
7374

74-
let arguments = (ins Variadic<SQLBoolType>:$expr);
75+
let arguments = (ins SQLBoolType:$left, SQLBoolType:$right);
7576
let results = (outs SQLBoolType:$result);
7677

7778
let hasFolder = 0;
@@ -110,6 +111,17 @@ def SQLToStringOp : SQL_Op<"to_string", [Pure]> {
110111
}
111112

112113

114+
def SQLBoolToStringOp : SQL_Op<"bool_to_string", [Pure]> {
115+
let summary = "bool_to_string";
116+
117+
let arguments = (ins SQLBoolType:$expr);
118+
let results = (outs AnyType:$result);
119+
120+
let hasFolder = 0;
121+
let hasCanonicalizer = 0;
122+
}
123+
124+
113125
def SQLStringConcatOp : SQL_Op<"string_concat", [Pure]> {
114126
let summary = "string_concat";
115127

@@ -123,7 +135,6 @@ def SQLStringConcatOp : SQL_Op<"string_concat", [Pure]> {
123135
def ConstantBoolOp : SQL_Op <"constant_bool", [Pure]> {
124136
let summary = "constant_bool";
125137
let results = (outs SQLBoolType:$result);
126-
127138
}
128139

129140

@@ -132,9 +143,8 @@ def SelectOp : SQL_Op<"select", [Pure]> {
132143
// i need to specify the size of a Variadic?
133144
let arguments = (ins Variadic<SQLExprType>:$columns,
134145
SQLExprType:$table,
135-
// SQLBoolType:$where,
136-
BoolAttr:$selectAll
137-
IntAttr:$limit);
146+
SQLExprType:$where,
147+
SI64Attr:$limit);
138148
// attribute limit<int> if >= 0 then its the real thing, otherwise its infinity
139149
let results = (outs SQLExprType:$result);
140150

lib/sql/Ops.cpp

Lines changed: 119 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8+
#include <algorithm>
9+
#include <regex>
810
#include <string>
911
#include <vector>
10-
#include <regex>
11-
#include <algorithm>
1212

1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -18,11 +18,11 @@
1818
#include "mlir/IR/OpImplementation.h"
1919
#include "mlir/Interfaces/SideEffectInterfaces.h"
2020

21+
#include "polygeist/Ops.h"
22+
#include "sql/Parser.h"
2123
#include "sql/SQLDialect.h"
2224
#include "sql/SQLOps.h"
2325
#include "sql/SQLTypes.h"
24-
#include "sql/Parser.h"
25-
#include "polygeist/Ops.h"
2626

2727
#define GET_OP_CLASSES
2828
#include "sql/SQLOps.cpp.inc"
@@ -41,21 +41,19 @@
4141
#include "llvm/ADT/SetVector.h"
4242
#include "llvm/Support/Debug.h"
4343

44-
45-
#include "mlir/IR/Value.h"
44+
#include "mlir/IR/Attributes.h"
4645
#include "mlir/IR/Builders.h"
46+
#include "mlir/IR/BuiltinTypes.h"
4747
#include "mlir/IR/Location.h"
48-
#include "mlir/IR/Attributes.h"
48+
#include "mlir/IR/Value.h"
4949
#include "llvm/ADT/SmallVector.h"
50-
#include "mlir/IR/BuiltinTypes.h"
5150

5251
#define DEBUG_TYPE "sql"
5352

5453
using namespace mlir;
5554
using namespace sql;
5655
using namespace mlir::arith;
5756

58-
5957
class GetValueOpTypeFix final : public OpRewritePattern<GetValueOp> {
6058
public:
6159
using OpRewritePattern<GetValueOp>::OpRewritePattern;
@@ -67,38 +65,38 @@ class GetValueOpTypeFix final : public OpRewritePattern<GetValueOp> {
6765

6866
Value handle = op.getOperand(0);
6967
if (!handle.getType().isa<IndexType>()) {
70-
handle = rewriter.create<IndexCastOp>(op.getLoc(),
71-
rewriter.getIndexType(), handle);
72-
changed = true;
68+
handle = rewriter.create<IndexCastOp>(op.getLoc(),
69+
rewriter.getIndexType(), handle);
70+
changed = true;
7371
}
7472
Value row = op.getOperand(1);
7573
if (!row.getType().isa<IndexType>()) {
76-
row = rewriter.create<IndexCastOp>(op.getLoc(),
77-
rewriter.getIndexType(), row);
78-
changed = true;
74+
row = rewriter.create<IndexCastOp>(op.getLoc(), rewriter.getIndexType(),
75+
row);
76+
changed = true;
7977
}
8078
Value column = op.getOperand(2);
8179
if (!column.getType().isa<IndexType>()) {
82-
column = rewriter.create<IndexCastOp>(op.getLoc(),
83-
rewriter.getIndexType(), column);
84-
changed = true;
80+
column = rewriter.create<IndexCastOp>(op.getLoc(),
81+
rewriter.getIndexType(), column);
82+
changed = true;
8583
}
8684

87-
if (!changed) return failure();
85+
if (!changed)
86+
return failure();
8887

89-
rewriter.replaceOpWithNewOp<GetValueOp>(op, op.getType(), handle, row, column);
88+
rewriter.replaceOpWithNewOp<GetValueOp>(op, op.getType(), handle, row,
89+
column);
9090

9191
return success(changed);
9292
}
9393
};
9494

9595
void GetValueOp::getCanonicalizationPatterns(RewritePatternSet &results,
96-
MLIRContext *context) {
96+
MLIRContext *context) {
9797
results.insert<GetValueOpTypeFix>(context);
9898
}
9999

100-
101-
102100
class NumResultsOpTypeFix final : public OpRewritePattern<NumResultsOp> {
103101
public:
104102
using OpRewritePattern<NumResultsOp>::OpRewritePattern;
@@ -108,34 +106,35 @@ class NumResultsOpTypeFix final : public OpRewritePattern<NumResultsOp> {
108106
bool changed = false;
109107
Value handle = op->getOperand(0);
110108

111-
if (handle.getType().isa<IndexType>() && op->getResultTypes()[0].isa<IndexType>())
112-
return failure();
109+
if (handle.getType().isa<IndexType>() &&
110+
op->getResultTypes()[0].isa<IndexType>())
111+
return failure();
113112

114113
if (!handle.getType().isa<IndexType>()) {
115-
handle = rewriter.create<IndexCastOp>(op.getLoc(),
116-
rewriter.getIndexType(), handle);
117-
changed = true;
114+
handle = rewriter.create<IndexCastOp>(op.getLoc(),
115+
rewriter.getIndexType(), handle);
116+
changed = true;
118117
}
119118

120-
mlir::Value res = rewriter.create<NumResultsOp>(op.getLoc(), rewriter.getIndexType(), handle);
119+
mlir::Value res = rewriter.create<NumResultsOp>(
120+
op.getLoc(), rewriter.getIndexType(), handle);
121121

122122
if (op->getResultTypes()[0].isa<IndexType>()) {
123-
rewriter.replaceOp(op, res);
123+
rewriter.replaceOp(op, res);
124124
} else {
125-
rewriter.replaceOpWithNewOp<IndexCastOp>(op, op->getResultTypes()[0], res);
125+
rewriter.replaceOpWithNewOp<IndexCastOp>(op, op->getResultTypes()[0],
126+
res);
126127
}
127128

128129
return success(changed);
129130
}
130131
};
131132

132133
void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results,
133-
MLIRContext *context) {
134+
MLIRContext *context) {
134135
results.insert<NumResultsOpTypeFix>(context);
135136
}
136137

137-
138-
139138
// class ExecuteOpTypeFix final : public OpRewritePattern<ExecuteOp> {
140139
// public:
141140
// using OpRewritePattern<ExecuteOp>::OpRewritePattern;
@@ -147,39 +146,44 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results,
147146
// Value conn = op->getOperand(0);
148147
// Value command = op->getOperand(1);
149148

150-
// if (conn.getType().isa<IndexType>() && command.getType().isa<IndexType>() && op->getResultTypes()[0].isa<IndexType>())
149+
// if (conn.getType().isa<IndexType>() && command.getType().isa<IndexType>()
150+
// && op->getResultTypes()[0].isa<IndexType>())
151151
// return failure();
152152

153153
// if (!conn.getType().isa<IndexType>()) {
154154
// conn = rewriter.create<IndexCastOp>(op.getLoc(),
155-
// rewriter.getIndexType(), conn);
155+
// rewriter.getIndexType(),
156+
// conn);
156157
// changed = true;
157158
// }
158159
// if (command.getType().isa<MemRefType>()) {
159-
// command = rewriter.create<polygeist::Memref2PointerOp>(op.getLoc(),
160-
// LLVM::LLVMPointerType::get(rewriter.getI8Type()), command);
160+
// command = rewriter.create<polygeist::Memref2PointerOp>(op.getLoc(),
161+
// LLVM::LLVMPointerType::get(rewriter.getI8Type()),
162+
// command);
161163
// changed = true;
162164
// }
163165

164-
165166
// if (command.getType().isa<LLVM::LLVMPointerType>()) {
166-
// command = rewriter.create<LLVM::PtrToIntOp>(op.getLoc(),
167-
// rewriter.getI64Type(), command);
167+
// command = rewriter.create<LLVM::PtrToIntOp>(op.getLoc(),
168+
// rewriter.getI64Type(),
169+
// command);
168170
// changed = true;
169171
// }
170172
// if (!command.getType().isa<IndexType>()) {
171-
// command = rewriter.create<IndexCastOp>(op.getLoc(),
172-
// rewriter.getIndexType(), command);
173+
// command = rewriter.create<IndexCastOp>(op.getLoc(),
174+
// rewriter.getIndexType(),
175+
// command);
173176
// changed = true;
174177
// }
175178

176179
// if (!changed) return failure();
177-
// mlir::Value res = rewriter.create<ExecuteOp>(op.getLoc(), rewriter.getIndexType(), conn, command);
178-
// rewriter.replaceOp(op, res);
180+
// mlir::Value res = rewriter.create<ExecuteOp>(op.getLoc(),
181+
// rewriter.getIndexType(), conn, command); rewriter.replaceOp(op, res);
179182
// // if (op->getResultTypes()[0].isa<IndexType>()) {
180183
// // rewriter.replaceOp(op, res);
181184
// // } else {
182-
// // rewriter.replaceOpWithNewOp<IndexCastOp>(op, op->getResultTypes()[0], res);
185+
// // rewriter.replaceOpWithNewOp<IndexCastOp>(op,
186+
// op->getResultTypes()[0], res);
183187
// // }
184188
// return success(changed);
185189
// }
@@ -190,8 +194,7 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results,
190194
// results.insert<ExecuteOpTypeFix>(context);
191195
// }
192196

193-
194-
template<typename T>
197+
template <typename T>
195198
class UnparsedOpInnerCast final : public OpRewritePattern<UnparsedOp> {
196199
public:
197200
using OpRewritePattern<UnparsedOp>::OpRewritePattern;
@@ -200,39 +203,91 @@ class UnparsedOpInnerCast final : public OpRewritePattern<UnparsedOp> {
200203
PatternRewriter &rewriter) const override {
201204

202205
Value input = op->getOperand(0);
203-
206+
204207
auto cst = input.getDefiningOp<T>();
205-
if (!cst) return failure();
208+
if (!cst)
209+
return failure();
206210

207211
rewriter.replaceOpWithNewOp<UnparsedOp>(op, op.getType(), cst.getOperand());
208212
return success();
209213
}
210214
};
211215

212216
void UnparsedOp::getCanonicalizationPatterns(RewritePatternSet &results,
213-
MLIRContext *context) {
214-
results.insert<UnparsedOpInnerCast<polygeist::Pointer2MemrefOp> >(context);
217+
MLIRContext *context) {
218+
results.insert<UnparsedOpInnerCast<polygeist::Pointer2MemrefOp>>(context);
215219
}
216220

217-
218-
class SQLStringConcatOpCanonicalization final : public OpRewritePattern<SQLStringConcatOp> {
221+
class SQLStringConcatOpCanonicalization final
222+
: public OpRewritePattern<SQLStringConcatOp> {
219223
public:
220224
using OpRewritePattern<SQLStringConcatOp>::OpRewritePattern;
221225

222226
LogicalResult matchAndRewrite(SQLStringConcatOp op,
223227
PatternRewriter &rewriter) const override {
224-
225-
auto input1 = op->getOperand(0).getDefiningOp<SQLConstantStringOp>();
226-
auto input2 = op->getOperand(1).getDefiningOp<SQLConstantStringOp>();
227-
228-
if (!input1 || !input2) return failure();
229-
230-
rewriter.replaceOpWithNewOp<SQLConstantStringOp>(op, op.getType(), (input1.getInput() + input2.getInput()).str());
231-
return success();
228+
// Whether we changed the state. If we make no simplifications we need to
229+
// return failure otherwise we will infinite loop
230+
bool changed = false;
231+
// Operands to the simplified concat
232+
SmallVector<Value> operands;
233+
// Constants that we will merge, "current running constant"
234+
SmallVector<SQLConstantStringOp> constants;
235+
for (auto op : op->getOperands()) {
236+
if (auto constOp = op.getDefiningOp<SQLConstantStringOp>()) {
237+
constants.push_back(constOp);
238+
continue;
239+
}
240+
if (constants.size() != 0) {
241+
if (constants.size() == 1) {
242+
operands.push_back(constants[0]);
243+
} else {
244+
std::string nextStr;
245+
changed = true;
246+
for (auto str : constants)
247+
nextStr += str.getInput().str();
248+
249+
operands.push_back(rewriter.create<SQLConstantStringOp>(
250+
op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), nextStr));
251+
}
252+
}
253+
constants.clear();
254+
if (auto concat = op.getDefiningOp<SQLStringConcatOp>()) {
255+
changed = true;
256+
for (auto op2 : concat->getOperands())
257+
operands.push_back(op2);
258+
continue;
259+
}
260+
operands.push_back(op);
261+
}
262+
if (constants.size() != 0) {
263+
if (constants.size() == 1) {
264+
operands.push_back(constants[0]);
265+
} else {
266+
std::string nextStr;
267+
changed = true;
268+
for (auto str : constants)
269+
nextStr = nextStr + str.getInput().str();
270+
operands.push_back(rewriter.create<SQLConstantStringOp>(
271+
op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), nextStr));
272+
}
273+
}
274+
if (operands.size() == 0) {
275+
rewriter.replaceOpWithNewOp<SQLConstantStringOp>(op, MemRefType::get({-1}, rewriter.getI8Type()), "");
276+
return success();
277+
}
278+
if (operands.size() == 1) {
279+
rewriter.replaceOp(op, operands[0]);
280+
return success();
281+
}
282+
if (changed) {
283+
rewriter.replaceOpWithNewOp<SQLStringConcatOp>(op, MemRefType::get({-1}, rewriter.getI8Type()), operands);
284+
return success();
285+
}
286+
return failure();
232287
}
233288
};
234289

235290
void SQLStringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
236-
MLIRContext *context) {
291+
MLIRContext *context) {
237292
results.insert<SQLStringConcatOpCanonicalization>(context);
238-
}
293+
}

0 commit comments

Comments
 (0)