@@ -161,16 +161,26 @@ struct ConstantStringOpLowering : public OpRewritePattern<sql::SQLConstantString
161
161
162
162
SymbolTableCollection symbolTable;
163
163
symbolTable.getSymbolTable (module );
164
-
165
- auto expr = (op.getInput () + " \0 " ).str ();
164
+ for (auto u: op.getResult ().getUsers ()){
165
+ if (isa<sql::SQLStringConcatOp>(u)) return failure ();
166
+ }
167
+ auto expr = op.getInput ().str ();
166
168
auto name = " str" + std::to_string ((long long int )(Operation *)op);
167
- auto MT = MemRefType::get ({expr.size ()}, rewriter.getI8Type ());
169
+ auto MT = MemRefType::get ({expr.size () + 1 }, rewriter.getI8Type ());
170
+ // auto type = MemRefType::get(mt.getShape(), mt.getElementType(), {});
168
171
auto getglob = rewriter.create <memref::GetGlobalOp>(op.getLoc (), MT, name);
172
+
173
+ SmallVector<char , 1 > data (expr.begin (), expr.end ());
174
+ data.push_back (' \0 ' );
175
+ auto attr = DenseElementsAttr::get<char >(
176
+ RankedTensorType::get (MT.getShape (), MT.getElementType ()), data);
169
177
170
- rewriter.setInsertionPointToStart (module .getBody ());
171
- auto res = rewriter.create <memref::GlobalOp>(op.getLoc (), rewriter.getStringAttr (name),
172
- mlir::StringAttr (), mlir::TypeAttr::get (MT), rewriter.getStringAttr (expr), mlir::UnitAttr (), /* alignment*/ nullptr );
178
+ auto loc = op.getLoc ();
173
179
rewriter.replaceOpWithNewOp <memref::CastOp>(op, MemRefType::get ({-1 }, rewriter.getI8Type ()), getglob.getResult ());
180
+ rewriter.setInsertionPointToStart (module .getBody ());
181
+ auto res = rewriter.create <memref::GlobalOp>(loc, rewriter.getStringAttr (name),
182
+ mlir::StringAttr (), mlir::TypeAttr::get (MT), attr, rewriter.getUnitAttr (), /* alignment*/ nullptr );
183
+
174
184
return success ();
175
185
}
176
186
};
@@ -207,22 +217,27 @@ struct ToStringOpLowering : public OpRewritePattern<sql::SQLToStringOp> {
207
217
Value current = rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
208
218
MemRefType::get ({-1 }, rewriter.getI8Type ()), " SELECT " );
209
219
bool prevColumn = false ;
210
- for (auto v : selectOp.getColumns ()) {
211
- Value columns = rewriter.create <sql::SQLToStringOp>(op.getLoc (),
212
- MemRefType::get ({-1 }, rewriter.getI8Type ()), v);
213
- Value args[] = { current, columns };
214
- current = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
215
- MemRefType::get ({-1 }, rewriter.getI8Type ()),args);
220
+ auto columns = selectOp.getColumns ();
221
+ for (mlir::Value v : columns) {
216
222
if (prevColumn) {
217
- Value args[] = { current, rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
218
- MemRefType::get ({-1 }, rewriter.getI8Type ()), " , " ) };
219
- current = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
220
- MemRefType::get ({-1 }, rewriter.getI8Type ()),args);
223
+ Value args[] = { current, rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
224
+ MemRefType::get ({-1 }, rewriter.getI8Type ()), " , " ) };
225
+ current = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
226
+ MemRefType::get ({-1 }, rewriter.getI8Type ()),args);
221
227
}
228
+ Value col = rewriter.create <sql::SQLToStringOp>(op.getLoc (),
229
+ MemRefType::get ({-1 }, rewriter.getI8Type ()), v);
230
+ Value args[] = { col, rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
231
+ MemRefType::get ({-1 }, rewriter.getI8Type ()), " " )};
232
+ col = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
233
+ MemRefType::get ({-1 }, rewriter.getI8Type ()), args);
234
+ Value args2[] = { current, col };
235
+ current = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
236
+ MemRefType::get ({-1 }, rewriter.getI8Type ()), args2);
222
237
prevColumn = true ;
223
238
}
224
239
auto tableOp = selectOp.getTable ().getDefiningOp <sql::TableOp>();
225
- if (! tableOp || !tableOp. getExpr (). empty () ) {
240
+ if (tableOp) {
226
241
Value args[] = { current, rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
227
242
MemRefType::get ({-1 }, rewriter.getI8Type ()), " FROM " ) };
228
243
current = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
@@ -233,6 +248,16 @@ struct ToStringOpLowering : public OpRewritePattern<sql::SQLToStringOp> {
233
248
MemRefType::get ({-1 }, rewriter.getI8Type ()),args2);
234
249
}
235
250
rewriter.replaceOp (op, current);
251
+ } else if (auto selectAllOp = dyn_cast<sql::SelectAllOp>(definingOp)){
252
+ auto table = rewriter.create <sql::SQLToStringOp>(op.getLoc (),
253
+ MemRefType::get ({-1 }, rewriter.getI8Type ()), selectAllOp.getTable ());
254
+ Value res = rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
255
+ MemRefType::get ({-1 }, rewriter.getI8Type ()), " SELECT * FROM " );
256
+ Value args[] = { res, table };
257
+ res = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
258
+ MemRefType::get ({-1 }, rewriter.getI8Type ()),args);
259
+
260
+ rewriter.replaceOp (op, res);
236
261
} else if (auto tabOp = dyn_cast<sql::TableOp>(definingOp)) {
237
262
Value res = rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
238
263
MemRefType::get ({-1 }, rewriter.getI8Type ()), tabOp.getExpr ());
@@ -244,6 +269,7 @@ struct ToStringOpLowering : public OpRewritePattern<sql::SQLToStringOp> {
244
269
} else if (auto intOp = dyn_cast<sql::IntOp>(definingOp)){
245
270
Value res = rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
246
271
MemRefType::get ({-1 }, rewriter.getI8Type ()), intOp.getExpr ());
272
+ llvm::errs () << " intOp: " << intOp.getExpr () << " \n " ;
247
273
rewriter.replaceOp (op, res);
248
274
} else {
249
275
assert (0 && " unknown type to convert to string" );
@@ -280,6 +306,8 @@ struct ExecuteOpLowering : public OpRewritePattern<sql::ExecuteOp> {
280
306
// auto name = "str" + std::to_string((long long int)(Operation *)command.getDefiningOp());
281
307
command = rewriter.create <sql::SQLToStringOp>(op.getLoc (),
282
308
MemRefType::get ({-1 }, rewriter.getI8Type ()), command);
309
+ llvm::errs () << " command: " << command << " \n " ;
310
+ llvm::errs () << " command type: " << command.getType () << " \n " ;
283
311
// auto type = MemRefType::get({-1}, rewriter.getI8Type());
284
312
285
313
@@ -295,6 +323,12 @@ struct ExecuteOpLowering : public OpRewritePattern<sql::ExecuteOp> {
295
323
Value res =
296
324
rewriter.create <mlir::func::CallOp>(op.getLoc (), executefn, args)
297
325
->getResult (0 );
326
+ res = rewriter.create <polygeist::Memref2PointerOp>(op.getLoc (),
327
+ LLVM::LLVMPointerType::get (rewriter.getI8Type ()), res);
328
+ res = rewriter.create <LLVM::PtrToIntOp>(
329
+ op.getLoc (), rewriter.getI64Type (), res);
330
+ res = rewriter.create <arith::IndexCastOp>(op.getLoc (),
331
+ op.getType (), res);
298
332
299
333
rewriter.replaceOp (op, res);
300
334
0 commit comments