Skip to content

Commit 6634ed0

Browse files
committed
mlir gen works; lowering still buggy
1 parent 32bc7d3 commit 6634ed0

File tree

9 files changed

+327
-11
lines changed

9 files changed

+327
-11
lines changed

include/sql/SQLOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def SelectOp : SQL_Op<"select", [Pure]> {
2626
def ExecuteOp : SQL_Op<"execute", []> {
2727
let summary = "execute query";
2828

29-
let arguments = (ins Index:$handle);
29+
let arguments = (ins Index:$conn, Index:$command);
3030
let results = (outs Index:$result);
3131

3232
let hasFolder = 0;
@@ -40,17 +40,17 @@ def NumResultsOp : SQL_Op<"num_results", [Pure]> {
4040
let results = (outs Index:$result);
4141

4242
let hasFolder = 0;
43-
let hasCanonicalizer = 0;
43+
let hasCanonicalizer = 1;
4444
}
4545

46-
def ResultOp : SQL_Op<"get_result", [Pure]> {
47-
let summary = "get results of execution";
46+
def GetValueOp : SQL_Op<"get_value", [Pure]> {
47+
let summary = "get value of execution";
4848

49-
let arguments = (ins Index:$handle, StrAttr:$column, Index:$row);
49+
let arguments = (ins Index:$handle, Index:$column, Index:$row);
5050
let results = (outs AnyType:$result);
5151

5252
let hasFolder = 0;
53-
let hasCanonicalizer = 0;
53+
let hasCanonicalizer = 1;
5454
}
5555

5656
#endif // SQL_OPS

lib/sql/Ops.cpp

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,84 @@
3737

3838
using namespace mlir;
3939
using namespace sql;
40-
using namespace mlir::arith;
40+
using namespace mlir::arith;
41+
42+
43+
class GetValueOpTypeFix final : public OpRewritePattern<GetValueOp> {
44+
public:
45+
using OpRewritePattern<GetValueOp>::OpRewritePattern;
46+
47+
LogicalResult matchAndRewrite(GetValueOp op,
48+
PatternRewriter &rewriter) const override {
49+
50+
bool changed = false;
51+
52+
Value handle = op.getOperand(0);
53+
54+
if (!handle.getType().isa<IndexType>()) {
55+
handle = rewriter.create<IndexCastOp>(op.getLoc(),
56+
rewriter.getIndexType(), handle);
57+
changed = true;
58+
}
59+
Value row = op.getOperand(1);
60+
if (!row.getType().isa<IndexType>()) {
61+
row = rewriter.create<IndexCastOp>(op.getLoc(),
62+
rewriter.getIndexType(), row);
63+
changed = true;
64+
}
65+
Value column = op.getOperand(2);
66+
if (!column.getType().isa<IndexType>()) {
67+
column = rewriter.create<IndexCastOp>(op.getLoc(),
68+
rewriter.getIndexType(), column);
69+
changed = true;
70+
}
71+
72+
if (!changed) return failure();
73+
74+
rewriter.replaceOpWithNewOp<GetValueOp>(op, op.getType(), handle, row, column);
75+
76+
return success(changed);
77+
}
78+
};
79+
80+
void GetValueOp::getCanonicalizationPatterns(RewritePatternSet &results,
81+
MLIRContext *context) {
82+
results.insert<GetValueOpTypeFix>(context);
83+
}
84+
85+
86+
87+
class NumResultsOpTypeFix final : public OpRewritePattern<NumResultsOp> {
88+
public:
89+
using OpRewritePattern<NumResultsOp>::OpRewritePattern;
90+
91+
LogicalResult matchAndRewrite(NumResultsOp op,
92+
PatternRewriter &rewriter) const override {
93+
bool changed = false;
94+
Value handle = op->getOperand(0);
95+
96+
if (handle.getType().isa<IndexType>() && op->getResultTypes()[0].isa<IndexType>())
97+
return failure();
98+
99+
if (!handle.getType().isa<IndexType>()) {
100+
handle = rewriter.create<IndexCastOp>(op.getLoc(),
101+
rewriter.getIndexType(), handle);
102+
changed = true;
103+
}
104+
105+
mlir::Value res = rewriter.create<NumResultsOp>(op.getLoc(), rewriter.getIndexType(), handle);
106+
107+
if (op->getResultTypes()[0].isa<IndexType>()) {
108+
rewriter.replaceOp(op, res);
109+
} else {
110+
rewriter.replaceOpWithNewOp<IndexCastOp>(op, op->getResultTypes()[0], res);
111+
}
112+
113+
return success(changed);
114+
}
115+
};
116+
117+
void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results,
118+
MLIRContext *context) {
119+
results.insert<NumResultsOpTypeFix>(context);
120+
}

lib/sql/Passes/SQLLower.cpp

Lines changed: 134 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct NumResultsOpLowering : public OpRewritePattern<sql::NumResultsOp> {
4848
auto module = loop->getParentOfType<ModuleOp>();
4949

5050
SymbolTableCollection symbolTable;
51-
symbolTable.getSymbolTable(loop);
51+
symbolTable.getSymbolTable(module);
5252

5353
// 1) make sure the postgres_getresult function is declared
5454
auto rowsfn = dyn_cast_or_null<func::FuncOp>(
@@ -80,7 +80,110 @@ struct NumResultsOpLowering : public OpRewritePattern<sql::NumResultsOp> {
8080
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
8181
loop, rewriter.getIndexType(), res2);
8282

83-
// 4) done
83+
return success();
84+
}
85+
};
86+
87+
88+
struct GetValueOpLowering : public OpRewritePattern<sql::GetValueOp> {
89+
using OpRewritePattern<sql::GetValueOp>::OpRewritePattern;
90+
91+
LogicalResult matchAndRewrite(sql::GetValueOp loop,
92+
PatternRewriter &rewriter) const final {
93+
auto module = loop->getParentOfType<ModuleOp>();
94+
95+
SymbolTableCollection symbolTable;
96+
symbolTable.getSymbolTable(module);
97+
98+
// 1) make sure the postgres_getresult function is declared
99+
auto valuefn = dyn_cast_or_null<func::FuncOp>(
100+
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQgetvalue")));
101+
102+
auto atoifn = dyn_cast_or_null<func::FuncOp>(
103+
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi")));
104+
105+
// 2) convert the args to valid args to postgres_getresult abi
106+
Value handle = loop.getHandle();
107+
handle = rewriter.create<arith::IndexCastOp>(loop.getLoc(),
108+
rewriter.getI64Type(), handle);
109+
handle = rewriter.create<LLVM::IntToPtrOp>(
110+
loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle);
111+
112+
Value row = loop.getRow();
113+
Value column = loop.getColumn();
114+
115+
116+
// 3) call and replace
117+
Value args[] = {handle, row, column};
118+
119+
Value res =
120+
rewriter.create<mlir::func::CallOp>(loop.getLoc(), valuefn, args)
121+
->getResult(0);
122+
123+
Value args2[] = {res};
124+
125+
Value res2 =
126+
rewriter.create<mlir::func::CallOp>(loop.getLoc(), atoifn, args2)
127+
->getResult(0);
128+
129+
if (loop.getType() != res2.getType()) {
130+
if (loop.getType().isa<IndexType>())
131+
res2 = rewriter.create<arith::IndexCastOp>(loop.getLoc(),
132+
loop.getType(), res2);
133+
else if (auto IT = loop.getType().dyn_cast<IntegerType>()) {
134+
auto IT2 = res2.getType().dyn_cast<IntegerType>();
135+
if (IT.getWidth() < IT2.getWidth()) {
136+
res2 = rewriter.create<arith::TruncIOp>(loop.getLoc(),
137+
loop.getType(), res2);
138+
} else if (IT.getWidth() > IT2.getWidth()) {
139+
res2 = rewriter.create<arith::ExtUIOp>(loop.getLoc(),
140+
loop.getType(), res2);
141+
} else assert(0 && "illegal integer type conversion");
142+
} else {
143+
assert(0 && "illegal type conversion");
144+
}
145+
}
146+
rewriter.replaceOp(loop, res2);
147+
148+
return success();
149+
}
150+
};
151+
152+
153+
154+
struct ExecuteOpLowering : public OpRewritePattern<sql::ExecuteOp> {
155+
using OpRewritePattern<sql::ExecuteOp>::OpRewritePattern;
156+
157+
LogicalResult matchAndRewrite(sql::ExecuteOp loop,
158+
PatternRewriter &rewriter) const final {
159+
auto module = loop->getParentOfType<ModuleOp>();
160+
161+
SymbolTableCollection symbolTable;
162+
symbolTable.getSymbolTable(module);
163+
164+
// 1) make sure the postgres_getresult function is declared
165+
auto executefn = dyn_cast_or_null<func::FuncOp>(
166+
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec")));
167+
168+
// 2) convert the args to valid args to postgres_getresult abi
169+
Value conn = loop.getConn();
170+
conn = rewriter.create<arith::IndexCastOp>(loop.getLoc(),
171+
rewriter.getI64Type(), conn);
172+
conn = rewriter.create<LLVM::IntToPtrOp>(
173+
loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn);
174+
175+
Value command = loop.getCommand();
176+
177+
// 3) call and replace
178+
Value args[] = {conn, command};
179+
180+
Value res =
181+
rewriter.create<mlir::func::CallOp>(loop.getLoc(), executefn, args)
182+
->getResult(0);
183+
184+
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
185+
loop, rewriter.getIndexType(), res);
186+
84187
return success();
85188
}
86189
};
@@ -103,9 +206,35 @@ void SQLLower::runOnOperation() {
103206
builder.getFunctionType(argtypes, rettypes));
104207
SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
105208
}
209+
210+
if (!dyn_cast_or_null<func::FuncOp>(symbolTable.lookupSymbolIn(
211+
module, builder.getStringAttr("PQgetvalue")))) {
212+
mlir::Type argtypes[] = {
213+
LLVM::LLVMPointerType::get(builder.getI8Type()),
214+
LLVM::LLVMPointerType::get(builder.getI64Type()),
215+
LLVM::LLVMPointerType::get(builder.getI64Type())};
216+
mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())};
217+
218+
auto fn =
219+
builder.create<func::FuncOp>(module.getLoc(), "PQgetvalue",
220+
builder.getFunctionType(argtypes, rettypes));
221+
SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
222+
}
223+
224+
// if (!dyn_cast_or_null<func::FuncOp>(
225+
// symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) {
226+
// mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI32Type()),
227+
// LLVM::LLVMPointerType::get(builder.getI8Type())};
228+
// mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI32Type())};
229+
230+
// auto fn = builder.create<func::FuncOp>(
231+
// module.getLoc(), "PQexec", builder.getFunctionType(argtypes, rettypes));
232+
// SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
233+
// }
234+
106235
if (!dyn_cast_or_null<func::FuncOp>(
107236
symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) {
108-
mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())};
237+
mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())};
109238

110239
// todo use data layout
111240
mlir::Type rettypes[] = {builder.getI64Type()};
@@ -117,6 +246,8 @@ void SQLLower::runOnOperation() {
117246

118247
RewritePatternSet patterns(&getContext());
119248
patterns.insert<NumResultsOpLowering>(&getContext());
249+
patterns.insert<GetValueOpLowering>(&getContext());
250+
// patterns.insert<ExecuteOpLowering>(&getContext());
120251

121252
GreedyRewriteConfig config;
122253
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),

lib/sql/Passes/SQLRaising.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,90 @@ struct PQcmdTuplesRaising : public OpRewritePattern<func::CallOp> {
7777
}
7878
};
7979

80+
81+
82+
struct PQgetvalueRaising : public OpRewritePattern<func::CallOp> {
83+
using OpRewritePattern<func::CallOp>::OpRewritePattern;
84+
85+
LogicalResult matchAndRewrite(func::CallOp call,
86+
PatternRewriter &rewriter) const final {
87+
if (call.getCallee() != "PQgetvalue") {
88+
return failure();
89+
}
90+
SymbolTableCollection symbolTable;
91+
symbolTable.getSymbolTable(call);
92+
auto module = call->getParentOfType<ModuleOp>();
93+
94+
// 2) convert the args to valid args to postgres_getresult abi
95+
Value handle = call.getArgOperands()[0];
96+
handle = rewriter.create<LLVM::PtrToIntOp>(
97+
call.getLoc(), rewriter.getIntegerType(64), handle);
98+
99+
handle = rewriter.create<arith::IndexCastOp>(call.getLoc(),
100+
rewriter.getIndexType(), handle);
101+
102+
Value row = call.getArgOperands()[1];
103+
Value column = call.getArgOperands()[2];
104+
105+
Value res = rewriter.create<sql::GetValueOp>(call.getLoc(), rewriter.getIndexType(), handle, row, column);
106+
// or Value res = rewriter.create<sql::GetValueOp>(call.getLoc(), rewriter.getIndexType(), {handle, row, column});
107+
108+
res = rewriter.create<arith::IndexCastOp>(call.getLoc(),
109+
rewriter.getI64Type(), res);
110+
111+
auto itoafn = dyn_cast_or_null<func::FuncOp>(
112+
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("itoa")));
113+
114+
Value args2[] = {res};
115+
116+
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(call, itoafn, args2);
117+
118+
// 4) done
119+
return success();
120+
}
121+
};
122+
123+
124+
struct PQexecRaising : public OpRewritePattern<func::CallOp> {
125+
using OpRewritePattern<func::CallOp>::OpRewritePattern;
126+
127+
LogicalResult matchAndRewrite(func::CallOp call,
128+
PatternRewriter &rewriter) const final {
129+
if (call.getCallee() != "PQexec") {
130+
return failure();
131+
}
132+
SymbolTableCollection symbolTable;
133+
symbolTable.getSymbolTable(call);
134+
auto module = call->getParentOfType<ModuleOp>();
135+
136+
// 2) convert the args to valid args to postgres_getresult abi
137+
Value conn = call.getArgOperands()[0];
138+
conn = rewriter.create<LLVM::PtrToIntOp>(
139+
call.getLoc(), rewriter.getIntegerType(64), conn);
140+
141+
conn = rewriter.create<arith::IndexCastOp>(call.getLoc(),
142+
rewriter.getIndexType(), conn);
143+
144+
Value command = call.getArgOperands()[1];
145+
command = rewriter.create<LLVM::PtrToIntOp>(
146+
call.getLoc(), rewriter.getIntegerType(64), command);
147+
148+
command = rewriter.create<arith::IndexCastOp>(call.getLoc(),
149+
rewriter.getIndexType(), command);
150+
151+
Value res = rewriter.create<sql::ExecuteOp>(call.getLoc(), rewriter.getIndexType(), conn, command);
152+
153+
res = rewriter.create<arith::IndexCastOp>(call.getLoc(),
154+
rewriter.getI64Type(), res);
155+
156+
rewriter.replaceOp(call, res);
157+
/// rewriter.replaceOpWithNewOp<mlir::func::CallOp>(call, itoafn, res);
158+
159+
// 4) done
160+
return success();
161+
}
162+
};
163+
80164
void SQLRaising::runOnOperation() {
81165
auto module = getOperation();
82166
SymbolTableCollection symbolTable;
@@ -98,6 +182,8 @@ void SQLRaising::runOnOperation() {
98182

99183
RewritePatternSet patterns(&getContext());
100184
patterns.insert<PQcmdTuplesRaising>(&getContext());
185+
patterns.insert<PQgetvalueRaising>(&getContext());
186+
// patterns.insert<PQexecRaising>(&getContext());
101187

102188
GreedyRewriteConfig config;
103189
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),

0 commit comments

Comments
 (0)