Skip to content

Commit 5249dbe

Browse files
committed
basic select
1 parent 2b568ce commit 5249dbe

File tree

4 files changed

+116
-39
lines changed

4 files changed

+116
-39
lines changed

include/sql/SQLOps.td

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,6 @@ def ColumnOp : SQL_Op<"column", [Pure]> {
3838
let hasCanonicalizer = 0;
3939
}
4040

41-
def AllColumnsOp : SQL_Op<"all_columns", [Pure]> {
42-
let summary = "all columns op";
43-
44-
let arguments = (ins StrAttr:$expr);
45-
let results = (outs SQLExprType:$result);
46-
47-
let hasFolder = 0;
48-
let hasCanonicalizer = 0;
49-
}
50-
5141
def WhereOp: SQL_Op<"where", [Pure]> {
5242
let summary = "where op";
5343

@@ -130,16 +120,39 @@ def SQLStringConcatOp : SQL_Op<"string_concat", [Pure]> {
130120
let hasCanonicalizer = 1;
131121
}
132122

123+
def ConstantBoolOp : SQL_Op <"constant_bool", [Pure]> {
124+
let summary = "constant_bool";
125+
let results = (outs SQLBoolType:$result);
126+
127+
}
128+
129+
133130
def SelectOp : SQL_Op<"select", [Pure]> {
134131
let summary = "select";
135132
// i need to specify the size of a Variadic?
136-
let arguments = (ins Variadic<SQLExprType>:$columns, SQLExprType:$table);
133+
let arguments = (ins Variadic<SQLExprType>:$columns,
134+
SQLExprType:$table,
135+
// SQLBoolType:$where,
136+
BoolAttr:$selectAll
137+
IntAttr:$limit);
138+
// attribute limit<int> if >= 0 then its the real thing, otherwise its infinity
139+
let results = (outs SQLExprType:$result);
140+
141+
let hasFolder = 0;
142+
let hasCanonicalizer = 0;
143+
}
144+
145+
def AllColumnsOp : SQL_Op<"all_columns", [Pure]> {
146+
let summary = "all_columns";
137147
let results = (outs SQLExprType:$result);
138148

139149
let hasFolder = 0;
140150
let hasCanonicalizer = 0;
141151
}
142152

153+
154+
155+
143156
def UnparsedOp : SQL_Op<"unparsed", [Pure]> {
144157
let summary = "unparsed sql op";
145158

include/sql/SQLTypes.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ def SQLExprType : SQL_Type<"Expr", "expr"> {
2828
// let assemblyFormat = "`<` $value `>`";
2929
}
3030

31-
3231
def SQLBoolType : SQL_Type<"Bool", "bool"> {
3332
let summary = "SQL boolean type";
3433
let description = "Custom attr or value type in sql dialect";

lib/sql/Parser.cpp

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class SQLParser {
106106
return {"", 0};
107107
}
108108
for (std::string rWord : reservedWords) {
109-
auto token = sql.substr(i, std::min(sql.size(), i + rWord.size()));
109+
auto token = sql.substr(i, std::min(sql.size()-i, rWord.size()));
110110
std::transform(token.begin(), token.end(), token.begin(), ::toupper);
111111
if (token == rWord) {
112112
return {token, static_cast<int>(token.size())};
@@ -148,12 +148,14 @@ class SQLParser {
148148
}
149149

150150

151-
152151
// Parse the next command, if any
153152
ParseValue parseNext(ParseMode mode) {
153+
// for (unsigned int j = i; j < sql.size(); j++) {
154+
// auto peekStr = peek();
155+
// pop();
156+
// llvm::errs() << "peekStrTest: " << i << " " << peekStr << "\n";
157+
// }
154158
if (i >= sql.size()) {
155-
llvm::errs() << "here i:" << i << "\n";
156-
llvm::errs() << "here size:" << sql.size() << "\n";
157159
return ParseValue();
158160
}
159161
auto peekStr = peek();
@@ -171,35 +173,64 @@ class SQLParser {
171173
llvm::SmallVector<Value> columns;
172174
bool hasColumns = true;
173175
bool hasWhere = false;
176+
bool selectAll = false;
177+
int limit = -1;
174178
Value table = nullptr;
175179
while (true) {
176180
peekStr = peek();
177181
if (hasColumns) {
178182
if (peekStr == "FROM") {
179183
pop();
180184
table = parseNext(ParseMode::Table).getValue();
185+
llvm::errs() << "table: " << table << "\n";
181186
hasColumns = false;
182187
break;
183188
}
189+
if (peekStr == "*") {
190+
pop();
191+
selectAll = true;
192+
continue;
193+
}
194+
if (peekStr == ",") {
195+
pop();
196+
llvm::errs() << "comma\n";
197+
continue;
198+
}
184199
ParseValue col = parseNext(ParseMode::Column);
185200
if (col.getType() == ParseType::Nothing) {
186201
hasColumns = false;
187202
break;
188203
} else {
189204
columns.push_back(col.getValue());
190205
}
191-
if (peekStr == ",") pop();
206+
192207
} else if (peekStr == "WHERE") {
193208
pop();
194209
hasWhere = true;
210+
} else if (peekStr == "LIMIT"){
211+
pop();
212+
peekStr = peek();
213+
if (peekStr == "ALL"){
214+
pop();
215+
} else if (is_number(&peekStr)){
216+
pop();
217+
limit = std::stoi(peekStr);
218+
}
195219
} else {
196-
break;
197-
// assert(0 && " additional clauses like limit/etc not yet handled");
220+
// break;
221+
assert(0 && " additional clauses like where/etc not yet handled");
198222
}
199223
}
200-
if (!table)
224+
if (!table){
225+
llvm::errs() << " table is null: " << table << "\n";
201226
table = builder.create<sql::TableOp>(loc, ExprType::get(builder.getContext()), builder.getStringAttr("")).getResult();
202-
return ParseValue(builder.create<sql::SelectOp>(loc, ExprType::get(builder.getContext()), columns, table).getResult());
227+
}
228+
// if (selectAll){
229+
// assert(table && "table cannot be null");
230+
// return ParseValue(builder.create<sql::SelectAllOp>(loc, ExprType::get(builder.getContext()), table).getResult());
231+
// } else {
232+
return ParseValue(builder.create<sql::SelectOp>(loc, ExprType::get(builder.getContext()), columns, table, selectALl, limit).getResult());
233+
// }
203234
} else if (is_number(&peekStr)){
204235
pop();
205236
return ParseValue(builder.create<IntOp>(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult());
@@ -234,7 +265,7 @@ class SQLParser {
234265
};
235266

236267
std::vector<std::string> SQLParser::reservedWords = {
237-
"(", ")", ">=", "<=", "!=", ",", "=", ">", "<", "SELECT", "DISTINCT", "INSERT INTO", "VALUES", "UPDATE", "DELETE FROM", "WHERE", "FROM", "SET", "AS"
268+
"(", ")", ">=", "<=", "!=", ",", "=", ">", "<", ",", "SELECT", "DISTINCT", "INSERT INTO", "VALUES", "UPDATE", "DELETE FROM", "WHERE", "FROM", "SET", "AS"
238269
};
239270

240271

lib/sql/Passes/SQLLower.cpp

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,26 @@ struct ConstantStringOpLowering : public OpRewritePattern<sql::SQLConstantString
161161

162162
SymbolTableCollection symbolTable;
163163
symbolTable.getSymbolTable(module);
164-
165-
auto expr = (op.getInput() + "\0").str();
164+
for (auto u: op.getResult().getUsers()){
165+
if (isa<sql::SQLStringConcatOp>(u)) return failure();
166+
}
167+
auto expr = op.getInput().str();
166168
auto name = "str" + std::to_string((long long int)(Operation *)op);
167-
auto MT = MemRefType::get({expr.size()}, rewriter.getI8Type());
169+
auto MT = MemRefType::get({expr.size() + 1}, rewriter.getI8Type());
170+
// auto type = MemRefType::get(mt.getShape(), mt.getElementType(), {});
168171
auto getglob = rewriter.create<memref::GetGlobalOp>(op.getLoc(), MT, name);
172+
173+
SmallVector<char, 1> data(expr.begin(), expr.end());
174+
data.push_back('\0');
175+
auto attr = DenseElementsAttr::get<char>(
176+
RankedTensorType::get(MT.getShape(), MT.getElementType()), data);
169177

170-
rewriter.setInsertionPointToStart(module.getBody());
171-
auto res = rewriter.create<memref::GlobalOp>(op.getLoc(), rewriter.getStringAttr(name),
172-
mlir::StringAttr(), mlir::TypeAttr::get(MT), rewriter.getStringAttr(expr), mlir::UnitAttr(), /*alignment*/nullptr);
178+
auto loc = op.getLoc();
173179
rewriter.replaceOpWithNewOp<memref::CastOp>(op, MemRefType::get({-1}, rewriter.getI8Type()), getglob.getResult());
180+
rewriter.setInsertionPointToStart(module.getBody());
181+
auto res = rewriter.create<memref::GlobalOp>(loc, rewriter.getStringAttr(name),
182+
mlir::StringAttr(), mlir::TypeAttr::get(MT), attr, rewriter.getUnitAttr(), /*alignment*/nullptr);
183+
174184
return success();
175185
}
176186
};
@@ -207,22 +217,27 @@ struct ToStringOpLowering : public OpRewritePattern<sql::SQLToStringOp> {
207217
Value current = rewriter.create<sql::SQLConstantStringOp>(op.getLoc(),
208218
MemRefType::get({-1}, rewriter.getI8Type()), "SELECT ");
209219
bool prevColumn = false;
210-
for (auto v : selectOp.getColumns()) {
211-
Value columns = rewriter.create<sql::SQLToStringOp>(op.getLoc(),
212-
MemRefType::get({-1}, rewriter.getI8Type()), v);
213-
Value args[] = { current, columns };
214-
current = rewriter.create<sql::SQLStringConcatOp>(op.getLoc(),
215-
MemRefType::get({-1}, rewriter.getI8Type()),args);
220+
auto columns = selectOp.getColumns();
221+
for (mlir::Value v : columns) {
216222
if (prevColumn) {
217-
Value args[] = { current, rewriter.create<sql::SQLConstantStringOp>(op.getLoc(),
218-
MemRefType::get({-1}, rewriter.getI8Type()), ", ") };
219-
current = rewriter.create<sql::SQLStringConcatOp>(op.getLoc(),
220-
MemRefType::get({-1}, rewriter.getI8Type()),args);
223+
Value args[] = { current, rewriter.create<sql::SQLConstantStringOp>(op.getLoc(),
224+
MemRefType::get({-1}, rewriter.getI8Type()), ", ") };
225+
current = rewriter.create<sql::SQLStringConcatOp>(op.getLoc(),
226+
MemRefType::get({-1}, rewriter.getI8Type()),args);
221227
}
228+
Value col = rewriter.create<sql::SQLToStringOp>(op.getLoc(),
229+
MemRefType::get({-1}, rewriter.getI8Type()), v);
230+
Value args[] = { col, rewriter.create<sql::SQLConstantStringOp>(op.getLoc(),
231+
MemRefType::get({-1}, rewriter.getI8Type()), " ")};
232+
col = rewriter.create<sql::SQLStringConcatOp>(op.getLoc(),
233+
MemRefType::get({-1}, rewriter.getI8Type()), args);
234+
Value args2[] = { current, col };
235+
current = rewriter.create<sql::SQLStringConcatOp>(op.getLoc(),
236+
MemRefType::get({-1}, rewriter.getI8Type()), args2);
222237
prevColumn = true;
223238
}
224239
auto tableOp = selectOp.getTable().getDefiningOp<sql::TableOp>();
225-
if (!tableOp || !tableOp.getExpr().empty()) {
240+
if (tableOp) {
226241
Value args[] = { current, rewriter.create<sql::SQLConstantStringOp>(op.getLoc(),
227242
MemRefType::get({-1}, rewriter.getI8Type()), "FROM ") };
228243
current = rewriter.create<sql::SQLStringConcatOp>(op.getLoc(),
@@ -233,6 +248,16 @@ struct ToStringOpLowering : public OpRewritePattern<sql::SQLToStringOp> {
233248
MemRefType::get({-1}, rewriter.getI8Type()),args2);
234249
}
235250
rewriter.replaceOp(op, current);
251+
} else if (auto selectAllOp = dyn_cast<sql::SelectAllOp>(definingOp)){
252+
auto table = rewriter.create<sql::SQLToStringOp>(op.getLoc(),
253+
MemRefType::get({-1}, rewriter.getI8Type()), selectAllOp.getTable());
254+
Value res = rewriter.create<sql::SQLConstantStringOp>(op.getLoc(),
255+
MemRefType::get({-1}, rewriter.getI8Type()), "SELECT * FROM ");
256+
Value args[] = { res, table };
257+
res = rewriter.create<sql::SQLStringConcatOp>(op.getLoc(),
258+
MemRefType::get({-1}, rewriter.getI8Type()),args);
259+
260+
rewriter.replaceOp(op, res);
236261
} else if (auto tabOp = dyn_cast<sql::TableOp>(definingOp)) {
237262
Value res = rewriter.create<sql::SQLConstantStringOp>(op.getLoc(),
238263
MemRefType::get({-1}, rewriter.getI8Type()), tabOp.getExpr());
@@ -244,6 +269,7 @@ struct ToStringOpLowering : public OpRewritePattern<sql::SQLToStringOp> {
244269
} else if (auto intOp = dyn_cast<sql::IntOp>(definingOp)){
245270
Value res = rewriter.create<sql::SQLConstantStringOp>(op.getLoc(),
246271
MemRefType::get({-1}, rewriter.getI8Type()), intOp.getExpr());
272+
llvm::errs() << "intOp: " << intOp.getExpr() << "\n";
247273
rewriter.replaceOp(op, res);
248274
} else {
249275
assert(0 && "unknown type to convert to string");
@@ -280,6 +306,8 @@ struct ExecuteOpLowering : public OpRewritePattern<sql::ExecuteOp> {
280306
// auto name = "str" + std::to_string((long long int)(Operation *)command.getDefiningOp());
281307
command = rewriter.create<sql::SQLToStringOp>(op.getLoc(),
282308
MemRefType::get({-1}, rewriter.getI8Type()), command);
309+
llvm::errs() << "command: " << command << "\n";
310+
llvm::errs() << "command type: " << command.getType() << "\n";
283311
// auto type = MemRefType::get({-1}, rewriter.getI8Type());
284312

285313

@@ -295,6 +323,12 @@ struct ExecuteOpLowering : public OpRewritePattern<sql::ExecuteOp> {
295323
Value res =
296324
rewriter.create<mlir::func::CallOp>(op.getLoc(), executefn, args)
297325
->getResult(0);
326+
res = rewriter.create<polygeist::Memref2PointerOp>(op.getLoc(),
327+
LLVM::LLVMPointerType::get(rewriter.getI8Type()), res);
328+
res = rewriter.create<LLVM::PtrToIntOp>(
329+
op.getLoc(), rewriter.getI64Type(), res);
330+
res = rewriter.create<arith::IndexCastOp>(op.getLoc(),
331+
op.getType(), res);
298332

299333
rewriter.replaceOp(op, res);
300334

0 commit comments

Comments
 (0)