Skip to content

Commit 5bed072

Browse files
committed
select op
1 parent 548d94a commit 5bed072

File tree

6 files changed

+286
-186
lines changed

6 files changed

+286
-186
lines changed

include/sql/SQLOps.td

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,67 @@
1212
include "mlir/IR/AttrTypeBase.td"
1313
include "SQLDialect.td"
1414

15+
def IntOp : SQL_Op<"int", [Pure]> {
16+
let summary = "select";
17+
18+
let arguments = (ins StrAttr:$expr);
19+
let results = (outs SQLExprType:$result);
20+
21+
let hasFolder = 0;
22+
let hasCanonicalizer = 0;
23+
}
24+
25+
def ColumnOp : SQL_Op<"column", [Pure]> {
26+
let summary = "select";
27+
28+
let arguments = (ins StrAttr:$expr);
29+
let results = (outs SQLExprType:$result);
30+
31+
let hasFolder = 0;
32+
let hasCanonicalizer = 0;
33+
}
34+
35+
def TableOp : SQL_Op<"table", [Pure]> {
36+
let summary = "select";
37+
38+
let arguments = (ins StrAttr:$expr);
39+
let results = (outs SQLExprType:$result);
40+
41+
let hasFolder = 0;
42+
let hasCanonicalizer = 0;
43+
}
44+
1545
def SelectOp : SQL_Op<"select", [Pure]> {
1646
let summary = "select";
1747

18-
// TODO: limit (optional), where clauses, join, etc
19-
let arguments = (ins StrArrayAttr:$column, StrAttr:$table);
20-
let results = (outs Index : $result);
48+
let arguments = (ins Variadic<SQLExprType>:$columns, Optional<SQLExprType>:$table);
49+
let results = (outs SQLExprType:$result);
2150

2251
let hasFolder = 0;
2352
let hasCanonicalizer = 0;
2453
}
2554

55+
def UnparsedOp : SQL_Op<"unparsed", [Pure]> {
56+
let summary = "unparsed sql op";
57+
58+
let arguments = (ins AnyType:$input);
59+
let results = (outs SQLExprType:$result);
60+
61+
let hasFolder = 0;
62+
let hasCanonicalizer = 1;
63+
}
64+
2665
def ExecuteOp : SQL_Op<"execute", []> {
2766
let summary = "execute query";
2867

29-
let arguments = (ins Index:$conn, Index:$command);
68+
let arguments = (ins Index:$conn, SQLExprType:$command);
3069
let results = (outs Index:$result);
3170

3271
let hasFolder = 0;
33-
let hasCanonicalizer = 0;
72+
let hasCanonicalizer = 1;
3473
}
3574

75+
3676
def NumResultsOp : SQL_Op<"num_results", [Pure]> {
3777
let summary = "number of results";
3878

@@ -52,5 +92,18 @@ def GetValueOp : SQL_Op<"get_value", [Pure]> {
5292
let hasFolder = 0;
5393
let hasCanonicalizer = 1;
5494
}
55-
95+
96+
// def SelectOp : SQL_Op<"select", [Pure]>{
97+
// let summary = "select";
98+
99+
// let arguments = (ins StrArrayAttr:$columns,
100+
// Optional<AnyMemRef>:$from);
101+
// // optional<list<clauses>>:$where,
102+
// // optional<int>:$limit,
103+
// // optional<int>:$order);
104+
// let results = (outs AnyType:$result);
105+
106+
// let hasFolder = 0;
107+
// let hasCanonicalizer = 0;
108+
// }
56109
#endif // SQL_OPS

include/sql/SQLTypes.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- SQLOps.h - SQL dialect ops --------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef SQLTYPES_H
10+
#define SQLTYPES_H
11+
12+
#include "mlir/IR/BuiltinTypes.h"
13+
#include "mlir/IR/Dialect.h"
14+
#include "mlir/IR/OpDefinition.h"
15+
#include "mlir/IR/PatternMatch.h"
16+
#include "mlir/Interfaces/SideEffectInterfaces.h"
17+
#include "mlir/Interfaces/ViewLikeInterface.h"
18+
#include "llvm/Support/CommandLine.h"
19+
20+
#define GET_TYPE_CLASSES
21+
#include "sql/SQLTypes.h.inc"
22+
23+
#endif

include/sql/SQLTypes.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- SQLTypes.td - SQL dialect types ----------------*- tablegen -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef SQL_TYPES
10+
#define SQL_TYPES
11+
12+
include "mlir/IR/AttrTypeBase.td"
13+
include "SQLDialect.td"
14+
15+
def SQLExprType : SQL_Type<"Expr", "expr"> {
16+
let summary = "SQL expression type";
17+
}
18+
19+
20+
#endif // SQL_TYPES

lib/sql/Ops.cpp

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "sql/SQLDialect.h"
1818
#include "sql/SQLOps.h"
19+
#include "polygeist/Ops.h"
1920

2021
#define GET_OP_CLASSES
2122
#include "sql/SQLOps.cpp.inc"
@@ -30,7 +31,6 @@
3031
#include "mlir/IR/Dominance.h"
3132
#include "mlir/IR/IntegerSet.h"
3233
#include "mlir/Transforms/SideEffectUtils.h"
33-
// #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
3434

3535
#include "llvm/ADT/SetVector.h"
3636
#include "llvm/Support/Debug.h"
@@ -52,7 +52,6 @@ class GetValueOpTypeFix final : public OpRewritePattern<GetValueOp> {
5252
bool changed = false;
5353

5454
Value handle = op.getOperand(0);
55-
5655
if (!handle.getType().isa<IndexType>()) {
5756
handle = rewriter.create<IndexCastOp>(op.getLoc(),
5857
rewriter.getIndexType(), handle);
@@ -119,4 +118,58 @@ class NumResultsOpTypeFix final : public OpRewritePattern<NumResultsOp> {
119118
void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results,
120119
MLIRContext *context) {
121120
results.insert<NumResultsOpTypeFix>(context);
121+
}
122+
123+
124+
125+
class ExecuteOpTypeFix final : public OpRewritePattern<ExecuteOp> {
126+
public:
127+
using OpRewritePattern<ExecuteOp>::OpRewritePattern;
128+
129+
LogicalResult matchAndRewrite(ExecuteOp op,
130+
PatternRewriter &rewriter) const override {
131+
bool changed = false;
132+
133+
Value conn = op->getOperand(0);
134+
Value command = op->getOperand(1);
135+
136+
if (conn.getType().isa<IndexType>() && command.getType().isa<IndexType>() && op->getResultTypes()[0].isa<IndexType>())
137+
return failure();
138+
139+
if (!conn.getType().isa<IndexType>()) {
140+
conn = rewriter.create<IndexCastOp>(op.getLoc(),
141+
rewriter.getIndexType(), conn);
142+
changed = true;
143+
}
144+
if (command.getType().isa<MemRefType>()) {
145+
command = rewriter.create<polygeist::Memref2PointerOp>(op.getLoc(),
146+
LLVM::LLVMPointerType::get(rewriter.getI8Type()), command);
147+
changed = true;
148+
}
149+
if (command.getType().isa<LLVM::LLVMPointerType>()) {
150+
command = rewriter.create<LLVM::PtrToIntOp>(op.getLoc(),
151+
rewriter.getI64Type(), command);
152+
changed = true;
153+
}
154+
if (!command.getType().isa<IndexType>()) {
155+
command = rewriter.create<IndexCastOp>(op.getLoc(),
156+
rewriter.getIndexType(), command);
157+
changed = true;
158+
}
159+
160+
if (!changed) return failure();
161+
mlir::Value res = rewriter.create<ExecuteOp>(op.getLoc(), rewriter.getIndexType(), conn, command);
162+
rewriter.replaceOp(op, res);
163+
// if (op->getResultTypes()[0].isa<IndexType>()) {
164+
// rewriter.replaceOp(op, res);
165+
// } else {
166+
// rewriter.replaceOpWithNewOp<IndexCastOp>(op, op->getResultTypes()[0], res);
167+
// }
168+
return success(changed);
169+
}
170+
};
171+
172+
void ExecuteOp::getCanonicalizationPatterns(RewritePatternSet &results,
173+
MLIRContext *context) {
174+
results.insert<ExecuteOpTypeFix>(context);
122175
}

0 commit comments

Comments
 (0)