@@ -48,7 +48,7 @@ struct NumResultsOpLowering : public OpRewritePattern<sql::NumResultsOp> {
48
48
auto module = loop->getParentOfType <ModuleOp>();
49
49
50
50
SymbolTableCollection symbolTable;
51
- symbolTable.getSymbolTable (loop );
51
+ symbolTable.getSymbolTable (module );
52
52
53
53
// 1) make sure the postgres_getresult function is declared
54
54
auto rowsfn = dyn_cast_or_null<func::FuncOp>(
@@ -80,7 +80,110 @@ struct NumResultsOpLowering : public OpRewritePattern<sql::NumResultsOp> {
80
80
rewriter.replaceOpWithNewOp <arith::IndexCastOp>(
81
81
loop, rewriter.getIndexType (), res2);
82
82
83
- // 4) done
83
+ return success ();
84
+ }
85
+ };
86
+
87
+
88
+ struct GetValueOpLowering : public OpRewritePattern <sql::GetValueOp> {
89
+ using OpRewritePattern<sql::GetValueOp>::OpRewritePattern;
90
+
91
+ LogicalResult matchAndRewrite (sql::GetValueOp loop,
92
+ PatternRewriter &rewriter) const final {
93
+ auto module = loop->getParentOfType <ModuleOp>();
94
+
95
+ SymbolTableCollection symbolTable;
96
+ symbolTable.getSymbolTable (module );
97
+
98
+ // 1) make sure the postgres_getresult function is declared
99
+ auto valuefn = dyn_cast_or_null<func::FuncOp>(
100
+ symbolTable.lookupSymbolIn (module , rewriter.getStringAttr (" PQgetvalue" )));
101
+
102
+ auto atoifn = dyn_cast_or_null<func::FuncOp>(
103
+ symbolTable.lookupSymbolIn (module , rewriter.getStringAttr (" atoi" )));
104
+
105
+ // 2) convert the args to valid args to postgres_getresult abi
106
+ Value handle = loop.getHandle ();
107
+ handle = rewriter.create <arith::IndexCastOp>(loop.getLoc (),
108
+ rewriter.getI64Type (), handle);
109
+ handle = rewriter.create <LLVM::IntToPtrOp>(
110
+ loop.getLoc (), LLVM::LLVMPointerType::get (rewriter.getI8Type ()), handle);
111
+
112
+ Value row = loop.getRow ();
113
+ Value column = loop.getColumn ();
114
+
115
+
116
+ // 3) call and replace
117
+ Value args[] = {handle, row, column};
118
+
119
+ Value res =
120
+ rewriter.create <mlir::func::CallOp>(loop.getLoc (), valuefn, args)
121
+ ->getResult (0 );
122
+
123
+ Value args2[] = {res};
124
+
125
+ Value res2 =
126
+ rewriter.create <mlir::func::CallOp>(loop.getLoc (), atoifn, args2)
127
+ ->getResult (0 );
128
+
129
+ if (loop.getType () != res2.getType ()) {
130
+ if (loop.getType ().isa <IndexType>())
131
+ res2 = rewriter.create <arith::IndexCastOp>(loop.getLoc (),
132
+ loop.getType (), res2);
133
+ else if (auto IT = loop.getType ().dyn_cast <IntegerType>()) {
134
+ auto IT2 = res2.getType ().dyn_cast <IntegerType>();
135
+ if (IT.getWidth () < IT2.getWidth ()) {
136
+ res2 = rewriter.create <arith::TruncIOp>(loop.getLoc (),
137
+ loop.getType (), res2);
138
+ } else if (IT.getWidth () > IT2.getWidth ()) {
139
+ res2 = rewriter.create <arith::ExtUIOp>(loop.getLoc (),
140
+ loop.getType (), res2);
141
+ } else assert (0 && " illegal integer type conversion" );
142
+ } else {
143
+ assert (0 && " illegal type conversion" );
144
+ }
145
+ }
146
+ rewriter.replaceOp (loop, res2);
147
+
148
+ return success ();
149
+ }
150
+ };
151
+
152
+
153
+
154
+ struct ExecuteOpLowering : public OpRewritePattern <sql::ExecuteOp> {
155
+ using OpRewritePattern<sql::ExecuteOp>::OpRewritePattern;
156
+
157
+ LogicalResult matchAndRewrite (sql::ExecuteOp loop,
158
+ PatternRewriter &rewriter) const final {
159
+ auto module = loop->getParentOfType <ModuleOp>();
160
+
161
+ SymbolTableCollection symbolTable;
162
+ symbolTable.getSymbolTable (module );
163
+
164
+ // 1) make sure the postgres_getresult function is declared
165
+ auto executefn = dyn_cast_or_null<func::FuncOp>(
166
+ symbolTable.lookupSymbolIn (module , rewriter.getStringAttr (" PQexec" )));
167
+
168
+ // 2) convert the args to valid args to postgres_getresult abi
169
+ Value conn = loop.getConn ();
170
+ conn = rewriter.create <arith::IndexCastOp>(loop.getLoc (),
171
+ rewriter.getI64Type (), conn);
172
+ conn = rewriter.create <LLVM::IntToPtrOp>(
173
+ loop.getLoc (), LLVM::LLVMPointerType::get (rewriter.getI8Type ()), conn);
174
+
175
+ Value command = loop.getCommand ();
176
+
177
+ // 3) call and replace
178
+ Value args[] = {conn, command};
179
+
180
+ Value res =
181
+ rewriter.create <mlir::func::CallOp>(loop.getLoc (), executefn, args)
182
+ ->getResult (0 );
183
+
184
+ rewriter.replaceOpWithNewOp <arith::IndexCastOp>(
185
+ loop, rewriter.getIndexType (), res);
186
+
84
187
return success ();
85
188
}
86
189
};
@@ -103,9 +206,35 @@ void SQLLower::runOnOperation() {
103
206
builder.getFunctionType (argtypes, rettypes));
104
207
SymbolTable::setSymbolVisibility (fn, SymbolTable::Visibility::Private);
105
208
}
209
+
210
+ if (!dyn_cast_or_null<func::FuncOp>(symbolTable.lookupSymbolIn (
211
+ module , builder.getStringAttr (" PQgetvalue" )))) {
212
+ mlir::Type argtypes[] = {
213
+ LLVM::LLVMPointerType::get (builder.getI8Type ()),
214
+ LLVM::LLVMPointerType::get (builder.getI64Type ()),
215
+ LLVM::LLVMPointerType::get (builder.getI64Type ())};
216
+ mlir::Type rettypes[] = {LLVM::LLVMPointerType::get (builder.getI8Type ())};
217
+
218
+ auto fn =
219
+ builder.create <func::FuncOp>(module .getLoc (), " PQgetvalue" ,
220
+ builder.getFunctionType (argtypes, rettypes));
221
+ SymbolTable::setSymbolVisibility (fn, SymbolTable::Visibility::Private);
222
+ }
223
+
224
+ // if (!dyn_cast_or_null<func::FuncOp>(
225
+ // symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) {
226
+ // mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI32Type()),
227
+ // LLVM::LLVMPointerType::get(builder.getI8Type())};
228
+ // mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI32Type())};
229
+
230
+ // auto fn = builder.create<func::FuncOp>(
231
+ // module.getLoc(), "PQexec", builder.getFunctionType(argtypes, rettypes));
232
+ // SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
233
+ // }
234
+
106
235
if (!dyn_cast_or_null<func::FuncOp>(
107
236
symbolTable.lookupSymbolIn (module , builder.getStringAttr (" atoi" )))) {
108
- mlir::Type argtypes[] = {LLVM::LLVMPointerType:: get (builder.getI8Type ())};
237
+ mlir::Type argtypes[] = {MemRefType:: get ({- 1 }, builder.getI8Type ())};
109
238
110
239
// todo use data layout
111
240
mlir::Type rettypes[] = {builder.getI64Type ()};
@@ -117,6 +246,8 @@ void SQLLower::runOnOperation() {
117
246
118
247
RewritePatternSet patterns (&getContext ());
119
248
patterns.insert <NumResultsOpLowering>(&getContext ());
249
+ patterns.insert <GetValueOpLowering>(&getContext ());
250
+ // patterns.insert<ExecuteOpLowering>(&getContext());
120
251
121
252
GreedyRewriteConfig config;
122
253
(void )applyPatternsAndFoldGreedily (getOperation (), std::move (patterns),
0 commit comments