Skip to content

Commit 32bc7d3

Browse files
committed
sql lowering workings with pragma
1 parent 7dae0f7 commit 32bc7d3

File tree

7 files changed

+244
-22
lines changed

7 files changed

+244
-22
lines changed

include/sql/Passes/Passes.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@ class RewritePatternSet;
1010
class DominanceInfo;
1111
namespace sql {
1212

13-
std::unique_ptr<Pass> createParallelLowerPass();
13+
std::unique_ptr<Pass> createSQLLowerPass();
14+
std::unique_ptr<Pass> createSQLRaisingPass();
1415
} // namespace sql
1516
} // namespace mlir
1617

18+
19+
1720
namespace mlir {
1821
// Forward declaration from Dialect.h
1922
template <typename ConcreteDialect>

include/sql/Passes/Passes.td

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,19 @@
44
include "mlir/Pass/PassBase.td"
55

66

7-
def ParallelLower : Pass<"sql-lower", "mlir::ModuleOp"> {
7+
def SQLLower : Pass<"sql-lower", "mlir::ModuleOp"> {
88
let summary = "Lower sql op to mlir";
99
let dependentDialects =
10-
["arith::AirthDialect", "func::FuncDialect", "LLVM::LLVMDialect"];
10+
["arith::ArithDialect", "func::FuncDialect", "LLVM::LLVMDialect"];
1111
let constructor = "mlir::sql::createSQLLowerPass()";
1212
}
1313

14+
15+
def SQLRaising : Pass<"sql-raising", "mlir::ModuleOp"> {
16+
let summary = "Raise sql op to mlir";
17+
let dependentDialects =
18+
["arith::ArithDialect", "func::FuncDialect", "LLVM::LLVMDialect"];
19+
let constructor = "mlir::sql::createSQLRaisingPass()";
20+
}
21+
1422
#endif // SQL_PASSES

lib/sql/Passes/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
add_mlir_dialect_library(MLIRSQLTransforms
22
SQLLower.cpp
3+
SQLRaising.cpp
34

45
DEPENDS
56
MLIRPolygeistOpsIncGen
67
MLIRPolygeistPassIncGen
8+
MLIRSQLPassIncGen
79

810
LINK_LIBS PUBLIC
911
MLIRArithDialect

lib/sql/Passes/PassDetails.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//===- PassDetails.h - polygeist pass class details ----------------*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// Stuff shared between the different polygeist passes.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
// clang-tidy seems to expect the absolute path in the header guard on some
15+
// systems, so just disable it.
16+
// NOLINTNEXTLINE(llvm-header-guard)
17+
#ifndef DIALECT_POLYGEIST_TRANSFORMS_PASSDETAILS_H
18+
#define DIALECT_POLYGEIST_TRANSFORMS_PASSDETAILS_H
19+
20+
#include "mlir/Pass/Pass.h"
21+
#include "sql/SQLOps.h"
22+
#include "sql/Passes/Passes.h"
23+
24+
namespace mlir {
25+
class FunctionOpInterface;
26+
// Forward declaration from Dialect.h
27+
template <typename ConcreteDialect>
28+
void registerDialect(DialectRegistry &registry);
29+
namespace sql {
30+
31+
class SQLDialect;
32+
33+
#define GEN_PASS_CLASSES
34+
#include "sql/Passes/Passes.h.inc"
35+
36+
} // namespace polygeist
37+
} // namespace mlir
38+
39+
#endif // DIALECT_POLYGEIST_TRANSFORMS_PASSDETAILS_H

lib/sql/Passes/SQLLower.cpp

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- SQLLower.cpp - Lower sql ops to mlir ------ -*-===//
1+
//===- SQLLower.cpp - Lower PostgreSQL to sql mlir ops ------ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -21,15 +21,17 @@
2121
#include "mlir/Transforms/Passes.h"
2222
#include "llvm/ADT/SetVector.h"
2323
#include "llvm/ADT/SmallPtrSet.h"
24+
#include "sql/SQLOps.h"
25+
#include "sql/Passes/Passes.h"
2426
#include <algorithm>
2527
#include <mutex>
2628

27-
#define DEBUG_TYPE "sql-opt"
29+
#define DEBUG_TYPE "sql-lower-opt"
2830

2931
using namespace mlir;
3032
using namespace mlir::arith;
3133
using namespace mlir::func;
32-
using namespace sql;
34+
using namespace mlir::sql;
3335

3436
namespace {
3537
struct SQLLower : public SQLLowerBase<SQLLower> {
@@ -45,26 +47,33 @@ struct NumResultsOpLowering : public OpRewritePattern<sql::NumResultsOp> {
4547
PatternRewriter &rewriter) const final {
4648
auto module = loop->getParentOfType<ModuleOp>();
4749

50+
SymbolTableCollection symbolTable;
51+
symbolTable.getSymbolTable(loop);
52+
4853
// 1) make sure the postgres_getresult function is declared
49-
auto rowsfn = dyn_cast_or_null<func::FuncOp>(symbolTable.lookupSymbolIn(
50-
module, builder.getStringAttr("PQcmdTuples")));
54+
auto rowsfn = dyn_cast_or_null<func::FuncOp>(
55+
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQcmdTuples")));
5156

5257
auto atoifn = dyn_cast_or_null<func::FuncOp>(
53-
symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")));
58+
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi")));
5459

5560
// 2) convert the args to valid args to postgres_getresult abi
5661
Value arg = loop.getHandle();
5762
arg = rewriter.create<arith::IndexCastOp>(loop.getLoc(),
58-
rewriter.getIntTy(64), arg);
63+
rewriter.getI64Type(), arg);
5964
arg = rewriter.create<LLVM::IntToPtrOp>(
60-
loop.getLoc(), LLVM::LLVMPointerType::get(builder.getInt8Ty()), arg);
65+
loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg);
6166

6267
// 3) call and replace
63-
Value args[] = {arg} Value res =
68+
Value args[] = {arg};
69+
70+
Value res =
6471
rewriter.create<mlir::func::CallOp>(loop.getLoc(), rowsfn, args)
6572
->getResult(0);
6673

67-
Value args2[] = {res} Value res2 =
74+
Value args2[] = {res};
75+
76+
Value res2 =
6877
rewriter.create<mlir::func::CallOp>(loop.getLoc(), atoifn, args2)
6978
->getResult(0);
7079

@@ -78,29 +87,32 @@ struct NumResultsOpLowering : public OpRewritePattern<sql::NumResultsOp> {
7887

7988
void SQLLower::runOnOperation() {
8089
auto module = getOperation();
90+
91+
SymbolTableCollection symbolTable;
92+
symbolTable.getSymbolTable(module);
8193
OpBuilder builder(module.getContext());
8294
builder.setInsertionPointToStart(module.getBody());
8395

8496
if (!dyn_cast_or_null<func::FuncOp>(symbolTable.lookupSymbolIn(
8597
module, builder.getStringAttr("PQcmdTuples")))) {
86-
mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())};
87-
mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())};
98+
mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())};
99+
mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())};
88100

89101
auto fn =
90102
builder.create<func::FuncOp>(module.getLoc(), "PQcmdTuples",
91-
builder.getFunctionType(argtys, rettys));
92-
SymbolTable::setSymbolVisibility(fn, SymbolTable::Private);
103+
builder.getFunctionType(argtypes, rettypes));
104+
SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
93105
}
94106
if (!dyn_cast_or_null<func::FuncOp>(
95107
symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) {
96-
mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())};
108+
mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())};
97109

98110
// todo use data layout
99-
mlir::Type rettypes[] = {builder.getIntTy(sizeof(int))};
111+
mlir::Type rettypes[] = {builder.getI64Type()};
100112

101113
auto fn = builder.create<func::FuncOp>(
102-
module.getLoc(), "atoi", builder.getFunctionType(argtys, rettys));
103-
SymbolTable::setSymbolVisibility(fn, SymbolTable::Private);
114+
module.getLoc(), "atoi", builder.getFunctionType(argtypes, rettypes));
115+
SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
104116
}
105117

106118
RewritePatternSet patterns(&getContext());
@@ -112,7 +124,7 @@ void SQLLower::runOnOperation() {
112124
}
113125

114126
namespace mlir {
115-
namespace polygeist {
127+
namespace sql {
116128
std::unique_ptr<Pass> createSQLLowerPass() {
117129
return std::make_unique<SQLLower>();
118130
}

lib/sql/Passes/SQLRaising.cpp

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//===- SQLLower.cpp - Lower PostgreSQL to sql mlir ops ------ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into
10+
// a generic SQL for representation
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "PassDetails.h"
14+
#include "mlir/Analysis/CallGraph.h"
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/Async/IR/Async.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
#include "mlir/Transforms/Passes.h"
22+
#include "llvm/ADT/SetVector.h"
23+
#include "llvm/ADT/SmallPtrSet.h"
24+
#include "sql/SQLOps.h"
25+
#include "sql/Passes/Passes.h"
26+
#include <algorithm>
27+
#include <mutex>
28+
29+
#define DEBUG_TYPE "sql-raising-opt"
30+
31+
using namespace mlir;
32+
using namespace mlir::arith;
33+
using namespace mlir::func;
34+
using namespace mlir::sql;
35+
36+
namespace {
37+
struct SQLRaising : public SQLRaisingBase<SQLRaising> {
38+
void runOnOperation() override;
39+
};
40+
41+
} // end anonymous namespace
42+
43+
struct PQcmdTuplesRaising : public OpRewritePattern<func::CallOp> {
44+
using OpRewritePattern<func::CallOp>::OpRewritePattern;
45+
46+
LogicalResult matchAndRewrite(func::CallOp call,
47+
PatternRewriter &rewriter) const final {
48+
if (call.getCallee() != "PQcmdTuples") {
49+
return failure();
50+
}
51+
SymbolTableCollection symbolTable;
52+
symbolTable.getSymbolTable(call);
53+
auto module = call->getParentOfType<ModuleOp>();
54+
55+
// 2) convert the args to valid args to postgres_getresult abi
56+
Value arg = call.getArgOperands()[0];
57+
arg = rewriter.create<LLVM::PtrToIntOp>(
58+
call.getLoc(), rewriter.getIntegerType(64), arg);
59+
60+
arg = rewriter.create<arith::IndexCastOp>(call.getLoc(),
61+
rewriter.getIndexType(), arg);
62+
63+
Value res = rewriter.create<sql::NumResultsOp>(call.getLoc(), rewriter.getIndexType(), arg);
64+
65+
res = rewriter.create<arith::IndexCastOp>(call.getLoc(),
66+
rewriter.getI64Type(), res);
67+
68+
auto itoafn = dyn_cast_or_null<func::FuncOp>(
69+
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("itoa")));
70+
71+
Value args2[] = {res};
72+
73+
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(call, itoafn, args2);
74+
75+
// 4) done
76+
return success();
77+
}
78+
};
79+
80+
void SQLRaising::runOnOperation() {
81+
auto module = getOperation();
82+
SymbolTableCollection symbolTable;
83+
symbolTable.getSymbolTable(module);
84+
OpBuilder builder(module.getContext());
85+
builder.setInsertionPointToStart(module.getBody());
86+
87+
if (!dyn_cast_or_null<func::FuncOp>(
88+
symbolTable.lookupSymbolIn(module, builder.getStringAttr("itoa")))) {
89+
mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())};
90+
91+
// todo use data layout
92+
mlir::Type argtypes[] = {builder.getI64Type()};
93+
94+
auto fn = builder.create<func::FuncOp>(
95+
module.getLoc(), "itoa", builder.getFunctionType(argtypes, rettypes));
96+
SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private);
97+
}
98+
99+
RewritePatternSet patterns(&getContext());
100+
patterns.insert<PQcmdTuplesRaising>(&getContext());
101+
102+
GreedyRewriteConfig config;
103+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
104+
config);
105+
}
106+
107+
namespace mlir {
108+
namespace sql {
109+
std::unique_ptr<Pass> createSQLRaisingPass() {
110+
return std::make_unique<SQLRaising>();
111+
}
112+
} // namespace polygeist
113+
} // namespace mlir

test_with_pragma.c

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include <stdio.h>
2+
#include <stdlib.h>
3+
#include <libpq-fe.h>
4+
5+
// PGresult *PQexec(PGconn*, const char* command);
6+
// PQgetvalue
7+
// %7 = call @PQexec(%2, %6) : (memref<?x1xi8>, memref<?xi8>) -> memref<?x1xi8>
8+
#pragma lower_to(num_rows_fn, "sql.num_results")
9+
int num_rows_fn(size_t);// char*
10+
11+
void do_exit(PGconn *conn) {
12+
13+
PQfinish(conn);
14+
exit(1);
15+
}
16+
17+
int main() {
18+
19+
PGconn *conn = PQconnectdb("user=janbodnar dbname=testdb");
20+
21+
if (PQstatus(conn) == CONNECTION_BAD) {
22+
23+
fprintf(stderr, "Connection to database failed: %s\n",
24+
PQerrorMessage(conn));
25+
do_exit(conn);
26+
}
27+
28+
PGresult *res = PQexec(conn, "SELECT VERSION()");
29+
30+
if (PQresultStatus(res) != PGRES_TUPLES_OK) {
31+
32+
printf("No data retrieved\n");
33+
PQclear(res);
34+
do_exit(conn);
35+
}
36+
37+
printf("%s\n", PQgetvalue(res, 0, 0));
38+
printf("%d\n", num_rows_fn((size_t)res));
39+
// res, 0, 0));
40+
41+
PQclear(res);
42+
PQfinish(conn);
43+
44+
return 0;
45+
}

0 commit comments

Comments
 (0)