Skip to content

Commit 548d94a

Browse files
committed
lowering still buggy
1 parent 6634ed0 commit 548d94a

File tree

2 files changed

+127
-71
lines changed

2 files changed

+127
-71
lines changed

lib/sql/Ops.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/IR/Builders.h"
1414
#include "mlir/IR/OpImplementation.h"
1515
#include "mlir/Interfaces/SideEffectInterfaces.h"
16+
1617
#include "sql/SQLDialect.h"
1718
#include "sql/SQLOps.h"
1819

@@ -29,6 +30,7 @@
2930
#include "mlir/IR/Dominance.h"
3031
#include "mlir/IR/IntegerSet.h"
3132
#include "mlir/Transforms/SideEffectUtils.h"
33+
// #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
3234

3335
#include "llvm/ADT/SetVector.h"
3436
#include "llvm/Support/Debug.h"
@@ -94,7 +96,7 @@ class NumResultsOpTypeFix final : public OpRewritePattern<NumResultsOp> {
9496
Value handle = op->getOperand(0);
9597

9698
if (handle.getType().isa<IndexType>() && op->getResultTypes()[0].isa<IndexType>())
97-
return failure();
99+
return failure();
98100

99101
if (!handle.getType().isa<IndexType>()) {
100102
handle = rewriter.create<IndexCastOp>(op.getLoc(),
@@ -105,9 +107,9 @@ class NumResultsOpTypeFix final : public OpRewritePattern<NumResultsOp> {
105107
mlir::Value res = rewriter.create<NumResultsOp>(op.getLoc(), rewriter.getIndexType(), handle);
106108

107109
if (op->getResultTypes()[0].isa<IndexType>()) {
108-
rewriter.replaceOp(op, res);
110+
rewriter.replaceOp(op, res);
109111
} else {
110-
rewriter.replaceOpWithNewOp<IndexCastOp>(op, op->getResultTypes()[0], res);
112+
rewriter.replaceOpWithNewOp<IndexCastOp>(op, op->getResultTypes()[0], res);
111113
}
112114

113115
return success(changed);

lib/sql/Passes/SQLLower.cpp

Lines changed: 122 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -40,45 +40,98 @@ struct SQLLower : public SQLLowerBase<SQLLower> {
4040

4141
} // end anonymous namespace
4242

43+
44+
struct ExecuteOpLowering : public OpRewritePattern<sql::ExecuteOp> {
45+
using OpRewritePattern<sql::ExecuteOp>::OpRewritePattern;
46+
47+
LogicalResult matchAndRewrite(sql::ExecuteOp op,
48+
PatternRewriter &rewriter) const final {
49+
auto module = op->getParentOfType<ModuleOp>();
50+
51+
SymbolTableCollection symbolTable;
52+
symbolTable.getSymbolTable(module);
53+
54+
// 1) make sure the postgres_getresult function is declared
55+
auto execfn = dyn_cast_or_null<func::FuncOp>(
56+
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec")));
57+
58+
auto atoifn = dyn_cast_or_null<func::FuncOp>(
59+
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi")));
60+
61+
// 2) convert the args to valid args to postgres_getresult abi
62+
Value conn = op.getConn();
63+
conn = rewriter.create<arith::IndexCastOp>(op.getLoc(),
64+
rewriter.getI64Type(), conn);
65+
conn = rewriter.create<LLVM::IntToPtrOp>(
66+
op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn);
67+
68+
Value command = op.getCommand();
69+
command = rewriter.create<arith::IndexCastOp>(op.getLoc(),
70+
rewriter.getI64Type(), command);
71+
command = rewriter.create<LLVM::IntToPtrOp>(
72+
op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), command);
73+
74+
// 3) call and replace
75+
Value args[] = {conn, command};
76+
77+
Value res =
78+
rewriter.create<mlir::func::CallOp>(op.getLoc(), execfn, args)
79+
->getResult(0);
80+
81+
Value args2[] = {res};
82+
83+
Value res2 =
84+
rewriter.create<mlir::func::CallOp>(op.getLoc(), atoifn, args2)
85+
->getResult(0);
86+
87+
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
88+
op, rewriter.getIndexType(), res2);
89+
90+
return success();
91+
}
92+
};
93+
94+
4395
struct NumResultsOpLowering : public OpRewritePattern<sql::NumResultsOp> {
4496
using OpRewritePattern<sql::NumResultsOp>::OpRewritePattern;
4597

46-
LogicalResult matchAndRewrite(sql::NumResultsOp loop,
98+
LogicalResult matchAndRewrite(sql::NumResultsOp op,
4799
PatternRewriter &rewriter) const final {
48-
auto module = loop->getParentOfType<ModuleOp>();
100+
auto module = op->getParentOfType<ModuleOp>();
49101

50102
SymbolTableCollection symbolTable;
51103
symbolTable.getSymbolTable(module);
52104

53105
// 1) make sure the postgres_getresult function is declared
54106
auto rowsfn = dyn_cast_or_null<func::FuncOp>(
55-
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQcmdTuples")));
107+
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQntuples")));
56108

57109
auto atoifn = dyn_cast_or_null<func::FuncOp>(
58110
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi")));
59111

60112
// 2) convert the args to valid args to postgres_getresult abi
61-
Value arg = loop.getHandle();
62-
arg = rewriter.create<arith::IndexCastOp>(loop.getLoc(),
113+
Value arg = op.getHandle();
114+
arg = rewriter.create<arith::IndexCastOp>(op.getLoc(),
63115
rewriter.getI64Type(), arg);
116+
64117
arg = rewriter.create<LLVM::IntToPtrOp>(
65-
loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg);
118+
op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg);
66119

67120
// 3) call and replace
68121
Value args[] = {arg};
69122

70123
Value res =
71-
rewriter.create<mlir::func::CallOp>(loop.getLoc(), rowsfn, args)
124+
rewriter.create<mlir::func::CallOp>(op.getLoc(), rowsfn, args)
72125
->getResult(0);
73126

74127
Value args2[] = {res};
75128

76129
Value res2 =
77-
rewriter.create<mlir::func::CallOp>(loop.getLoc(), atoifn, args2)
130+
rewriter.create<mlir::func::CallOp>(op.getLoc(), atoifn, args2)
78131
->getResult(0);
79132

80133
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
81-
loop, rewriter.getIndexType(), res2);
134+
op, rewriter.getIndexType(), res2);
82135

83136
return success();
84137
}
@@ -88,9 +141,9 @@ struct NumResultsOpLowering : public OpRewritePattern<sql::NumResultsOp> {
88141
struct GetValueOpLowering : public OpRewritePattern<sql::GetValueOp> {
89142
using OpRewritePattern<sql::GetValueOp>::OpRewritePattern;
90143

91-
LogicalResult matchAndRewrite(sql::GetValueOp loop,
144+
LogicalResult matchAndRewrite(sql::GetValueOp op,
92145
PatternRewriter &rewriter) const final {
93-
auto module = loop->getParentOfType<ModuleOp>();
146+
auto module = op->getParentOfType<ModuleOp>();
94147

95148
SymbolTableCollection symbolTable;
96149
symbolTable.getSymbolTable(module);
@@ -103,90 +156,90 @@ struct GetValueOpLowering : public OpRewritePattern<sql::GetValueOp> {
103156
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi")));
104157

105158
// 2) convert the args to valid args to postgres_getresult abi
106-
Value handle = loop.getHandle();
107-
handle = rewriter.create<arith::IndexCastOp>(loop.getLoc(),
159+
Value handle = op.getHandle();
160+
handle = rewriter.create<arith::IndexCastOp>(op.getLoc(),
108161
rewriter.getI64Type(), handle);
109162
handle = rewriter.create<LLVM::IntToPtrOp>(
110-
loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle);
163+
op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle);
111164

112-
Value row = loop.getRow();
113-
Value column = loop.getColumn();
165+
Value row = op.getRow();
166+
Value column = op.getColumn();
114167

115168

116169
// 3) call and replace
117170
Value args[] = {handle, row, column};
118171

119172
Value res =
120-
rewriter.create<mlir::func::CallOp>(loop.getLoc(), valuefn, args)
173+
rewriter.create<mlir::func::CallOp>(op.getLoc(), valuefn, args)
121174
->getResult(0);
122175

123176
Value args2[] = {res};
124177

125178
Value res2 =
126-
rewriter.create<mlir::func::CallOp>(loop.getLoc(), atoifn, args2)
179+
rewriter.create<mlir::func::CallOp>(op.getLoc(), atoifn, args2)
127180
->getResult(0);
128181

129-
if (loop.getType() != res2.getType()) {
130-
if (loop.getType().isa<IndexType>())
131-
res2 = rewriter.create<arith::IndexCastOp>(loop.getLoc(),
132-
loop.getType(), res2);
133-
else if (auto IT = loop.getType().dyn_cast<IntegerType>()) {
182+
if (op.getType() != res2.getType()) {
183+
if (op.getType().isa<IndexType>())
184+
res2 = rewriter.create<arith::IndexCastOp>(op.getLoc(),
185+
op.getType(), res2);
186+
else if (auto IT = op.getType().dyn_cast<IntegerType>()) {
134187
auto IT2 = res2.getType().dyn_cast<IntegerType>();
135188
if (IT.getWidth() < IT2.getWidth()) {
136-
res2 = rewriter.create<arith::TruncIOp>(loop.getLoc(),
137-
loop.getType(), res2);
189+
res2 = rewriter.create<arith::TruncIOp>(op.getLoc(),
190+
op.getType(), res2);
138191
} else if (IT.getWidth() > IT2.getWidth()) {
139-
res2 = rewriter.create<arith::ExtUIOp>(loop.getLoc(),
140-
loop.getType(), res2);
192+
res2 = rewriter.create<arith::ExtUIOp>(op.getLoc(),
193+
op.getType(), res2);
141194
} else assert(0 && "illegal integer type conversion");
142195
} else {
143196
assert(0 && "illegal type conversion");
144197
}
145198
}
146-
rewriter.replaceOp(loop, res2);
199+
rewriter.replaceOp(op, res2);
147200

148201
return success();
149202
}
150203
};
151204

152205

153206

154-
struct ExecuteOpLowering : public OpRewritePattern<sql::ExecuteOp> {
155-
using OpRewritePattern<sql::ExecuteOp>::OpRewritePattern;
207+
// struct ExecuteOpLowering : public OpRewritePattern<sql::ExecuteOp> {
208+
// using OpRewritePattern<sql::ExecuteOp>::OpRewritePattern;
156209

157-
LogicalResult matchAndRewrite(sql::ExecuteOp loop,
158-
PatternRewriter &rewriter) const final {
159-
auto module = loop->getParentOfType<ModuleOp>();
210+
// LogicalResult matchAndRewrite(sql::ExecuteOp op,
211+
// PatternRewriter &rewriter) const final {
212+
// auto module = op->getParentOfType<ModuleOp>();
160213

161-
SymbolTableCollection symbolTable;
162-
symbolTable.getSymbolTable(module);
214+
// SymbolTableCollection symbolTable;
215+
// symbolTable.getSymbolTable(module);
163216

164-
// 1) make sure the postgres_getresult function is declared
165-
auto executefn = dyn_cast_or_null<func::FuncOp>(
166-
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec")));
217+
// // 1) make sure the postgres_getresult function is declared
218+
// auto executefn = dyn_cast_or_null<func::FuncOp>(
219+
// symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec")));
167220

168-
// 2) convert the args to valid args to postgres_getresult abi
169-
Value conn = loop.getConn();
170-
conn = rewriter.create<arith::IndexCastOp>(loop.getLoc(),
171-
rewriter.getI64Type(), conn);
172-
conn = rewriter.create<LLVM::IntToPtrOp>(
173-
loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn);
221+
// // 2) convert the args to valid args to postgres_getresult abi
222+
// Value conn = op.getConn();
223+
// conn = rewriter.create<arith::IndexCastOp>(op.getLoc(),
224+
// rewriter.getI64Type(), conn);
225+
// conn = rewriter.create<LLVM::IntToPtrOp>(
226+
// op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn);
174227

175-
Value command = loop.getCommand();
228+
// Value command = op.getCommand();
176229

177-
// 3) call and replace
178-
Value args[] = {conn, command};
230+
// // 3) call and replace
231+
// Value args[] = {conn, command};
179232

180-
Value res =
181-
rewriter.create<mlir::func::CallOp>(loop.getLoc(), executefn, args)
182-
->getResult(0);
233+
// Value res =
234+
// rewriter.create<mlir::func::CallOp>(op.getLoc(), executefn, args)
235+
// ->getResult(0);
183236

184-
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
185-
loop, rewriter.getIndexType(), res);
237+
// rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
238+
// op, rewriter.getIndexType(), res);
186239

187-
return success();
188-
}
189-
};
240+
// return success();
241+
// }
242+
// };
190243

191244
void SQLLower::runOnOperation() {
192245
auto module = getOperation();
@@ -197,12 +250,12 @@ void SQLLower::runOnOperation() {
197250
builder.setInsertionPointToStart(module.getBody());
198251

199252
if (!dyn_cast_or_null<func::FuncOp>(symbolTable.lookupSymbolIn(
200-
module, builder.getStringAttr("PQcmdTuples")))) {
253+
module, builder.getStringAttr("PQntuples")))) {
201254
mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())};
202-
mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())};
255+
mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())};
203256

204257
auto fn =
205-
builder.create<func::FuncOp>(module.getLoc(), "PQcmdTuples",
258+
builder.create<func::FuncOp>(module.getLoc(), "PQntuples",
206259
builder.getFunctionType(argtypes, rettypes));
207260
SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
208261
}
@@ -221,20 +274,21 @@ void SQLLower::runOnOperation() {
221274
SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
222275
}
223276

224-
// if (!dyn_cast_or_null<func::FuncOp>(
225-
// symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) {
226-
// mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI32Type()),
227-
// LLVM::LLVMPointerType::get(builder.getI8Type())};
228-
// mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI32Type())};
277+
// if (!dyn_cast_or_null<func::FuncOp>(
278+
// symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) {
279+
// mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type()),
280+
// LLVM::LLVMPointerType::get(builder.getI8Type())};
281+
// mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())};
229282

230-
// auto fn = builder.create<func::FuncOp>(
231-
// module.getLoc(), "PQexec", builder.getFunctionType(argtypes, rettypes));
232-
// SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
233-
// }
283+
// auto fn = builder.create<func::FuncOp>(
284+
// module.getLoc(), "PQexec", builder.getFunctionType(argtypes, rettypes));
285+
// SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
286+
// }
234287

235288
if (!dyn_cast_or_null<func::FuncOp>(
236289
symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) {
237-
mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())};
290+
// mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())};
291+
mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())};
238292

239293
// todo use data layout
240294
mlir::Type rettypes[] = {builder.getI64Type()};

0 commit comments

Comments
 (0)