Skip to content

Commit 7dae0f7

Browse files
committed
SQLLower
1 parent a324331 commit 7dae0f7

File tree

7 files changed

+345
-0
lines changed

7 files changed

+345
-0
lines changed

include/sql/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
add_mlir_dialect(SQLOps sql)
22
# add_mlir_doc(SQLDialect -gen-dialect-doc SQLDialect SQL/)
33
# add_mlir_doc(SQLOps -gen-op-doc SQLOps SQL/)
4+
5+
add_subdirectory(Passes)

include/sql/Passes/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name sql)
3+
add_public_tablegen_target(MLIRSQLPassIncGen)
4+
5+
add_mlir_doc(Passes SQLPasses ./ -gen-pass-doc)

include/sql/Passes/Passes.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#ifndef SQL_DIALECT_SQL_PASSES_H
2+
#define SQL_DIALECT_SQL_PASSES_H
3+
4+
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
5+
#include "mlir/Pass/Pass.h"
6+
#include <memory>
7+
namespace mlir {
8+
class PatternRewriter;
9+
class RewritePatternSet;
10+
class DominanceInfo;
11+
namespace sql {
12+
13+
std::unique_ptr<Pass> createParallelLowerPass();
14+
} // namespace sql
15+
} // namespace mlir
16+
17+
namespace mlir {
18+
// Forward declaration from Dialect.h
19+
template <typename ConcreteDialect>
20+
void registerDialect(DialectRegistry &registry);
21+
22+
namespace arith {
23+
class ArithDialect;
24+
} // end namespace arith
25+
26+
namespace scf {
27+
class SCFDialect;
28+
} // end namespace scf
29+
30+
namespace memref {
31+
class MemRefDialect;
32+
} // end namespace memref
33+
34+
namespace func {
35+
class FuncDialect;
36+
}
37+
38+
class AffineDialect;
39+
namespace LLVM {
40+
class LLVMDialect;
41+
}
42+
43+
#define GEN_PASS_REGISTRATION
44+
#include "sql/Passes/Passes.h.inc"
45+
46+
} // end namespace mlir
47+
48+
#endif // SQL_DIALECT_SQL_PASSES_H

include/sql/Passes/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#ifndef SQL_PASSES
2+
#define SQL_PASSES
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
7+
def ParallelLower : Pass<"sql-lower", "mlir::ModuleOp"> {
8+
let summary = "Lower sql op to mlir";
9+
let dependentDialects =
10+
["arith::AirthDialect", "func::FuncDialect", "LLVM::LLVMDialect"];
11+
let constructor = "mlir::sql::createSQLLowerPass()";
12+
}
13+
14+
#endif // SQL_PASSES

include/sql/Passes/Utils.h

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
#pragma once
2+
3+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
4+
#include "mlir/Dialect/SCF/IR/SCF.h"
5+
#include "mlir/IR/BlockAndValueMapping.h"
6+
#include "mlir/IR/IntegerSet.h"
7+
8+
static inline mlir::scf::IfOp
9+
cloneWithResults(mlir::scf::IfOp op, mlir::OpBuilder &rewriter,
10+
mlir::BlockAndValueMapping mapping = {}) {
11+
using namespace mlir;
12+
return rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
13+
mapping.lookupOrDefault(op.getCondition()),
14+
true);
15+
}
16+
static inline mlir::AffineIfOp
17+
cloneWithResults(mlir::AffineIfOp op, mlir::OpBuilder &rewriter,
18+
mlir::BlockAndValueMapping mapping = {}) {
19+
using namespace mlir;
20+
SmallVector<mlir::Value> lower;
21+
for (auto o : op.getOperands())
22+
lower.push_back(mapping.lookupOrDefault(o));
23+
return rewriter.create<AffineIfOp>(op.getLoc(), op.getResultTypes(),
24+
op.getIntegerSet(), lower, true);
25+
}
26+
27+
static inline mlir::scf::IfOp
28+
cloneWithoutResults(mlir::scf::IfOp op, mlir::OpBuilder &rewriter,
29+
mlir::BlockAndValueMapping mapping = {},
30+
mlir::TypeRange types = {}) {
31+
using namespace mlir;
32+
return rewriter.create<scf::IfOp>(
33+
op.getLoc(), types, mapping.lookupOrDefault(op.getCondition()), true);
34+
}
35+
static inline mlir::AffineIfOp
36+
cloneWithoutResults(mlir::AffineIfOp op, mlir::OpBuilder &rewriter,
37+
mlir::BlockAndValueMapping mapping = {},
38+
mlir::TypeRange types = {}) {
39+
using namespace mlir;
40+
SmallVector<mlir::Value> lower;
41+
for (auto o : op.getOperands())
42+
lower.push_back(mapping.lookupOrDefault(o));
43+
return rewriter.create<AffineIfOp>(op.getLoc(), types, op.getIntegerSet(),
44+
lower, true);
45+
}
46+
47+
static inline mlir::scf::ForOp
48+
cloneWithoutResults(mlir::scf::ForOp op, mlir::PatternRewriter &rewriter,
49+
mlir::BlockAndValueMapping mapping = {}) {
50+
using namespace mlir;
51+
return rewriter.create<scf::ForOp>(
52+
op.getLoc(), mapping.lookupOrDefault(op.getLowerBound()),
53+
mapping.lookupOrDefault(op.getUpperBound()),
54+
mapping.lookupOrDefault(op.getStep()));
55+
}
56+
static inline mlir::AffineForOp
57+
cloneWithoutResults(mlir::AffineForOp op, mlir::PatternRewriter &rewriter,
58+
mlir::BlockAndValueMapping mapping = {}) {
59+
using namespace mlir;
60+
SmallVector<Value> lower;
61+
for (auto o : op.getLowerBoundOperands())
62+
lower.push_back(mapping.lookupOrDefault(o));
63+
SmallVector<Value> upper;
64+
for (auto o : op.getUpperBoundOperands())
65+
upper.push_back(mapping.lookupOrDefault(o));
66+
return rewriter.create<AffineForOp>(op.getLoc(), lower, op.getLowerBoundMap(),
67+
upper, op.getUpperBoundMap(),
68+
op.getStep());
69+
}
70+
71+
static inline void clearBlock(mlir::Block *block,
72+
mlir::PatternRewriter &rewriter) {
73+
for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
74+
assert(op.use_empty() && "expected 'op' to have no uses");
75+
rewriter.eraseOp(&op);
76+
}
77+
}
78+
79+
static inline mlir::Block *getThenBlock(mlir::scf::IfOp op) {
80+
return op.thenBlock();
81+
}
82+
static inline mlir::Block *getThenBlock(mlir::AffineIfOp op) {
83+
return op.getThenBlock();
84+
}
85+
static inline mlir::Block *getElseBlock(mlir::scf::IfOp op) {
86+
return op.elseBlock();
87+
}
88+
static inline mlir::Block *getElseBlock(mlir::AffineIfOp op) {
89+
if (op.hasElse())
90+
return op.getElseBlock();
91+
else
92+
return nullptr;
93+
}
94+
95+
static inline mlir::Region &getThenRegion(mlir::scf::IfOp op) {
96+
return op.getThenRegion();
97+
}
98+
static inline mlir::Region &getThenRegion(mlir::AffineIfOp op) {
99+
return op.getThenRegion();
100+
}
101+
static inline mlir::Region &getElseRegion(mlir::scf::IfOp op) {
102+
return op.getElseRegion();
103+
}
104+
static inline mlir::Region &getElseRegion(mlir::AffineIfOp op) {
105+
return op.getElseRegion();
106+
}
107+
108+
static inline mlir::scf::YieldOp getThenYield(mlir::scf::IfOp op) {
109+
return op.thenYield();
110+
}
111+
static inline mlir::AffineYieldOp getThenYield(mlir::AffineIfOp op) {
112+
return llvm::cast<mlir::AffineYieldOp>(op.getThenBlock()->getTerminator());
113+
}
114+
static inline mlir::scf::YieldOp getElseYield(mlir::scf::IfOp op) {
115+
return op.elseYield();
116+
}
117+
static inline mlir::AffineYieldOp getElseYield(mlir::AffineIfOp op) {
118+
return llvm::cast<mlir::AffineYieldOp>(op.getElseBlock()->getTerminator());
119+
}
120+
121+
static inline bool inBound(mlir::scf::IfOp op, mlir::Value v) {
122+
return op.getCondition() == v;
123+
}
124+
static inline bool inBound(mlir::AffineIfOp op, mlir::Value v) {
125+
return llvm::any_of(op.getOperands(), [&](mlir::Value e) { return e == v; });
126+
}
127+
static inline bool inBound(mlir::scf::ForOp op, mlir::Value v) {
128+
return op.getUpperBound() == v;
129+
}
130+
static inline bool inBound(mlir::AffineForOp op, mlir::Value v) {
131+
return llvm::any_of(op.getUpperBoundOperands(),
132+
[&](mlir::Value e) { return e == v; });
133+
}
134+
static inline bool hasElse(mlir::scf::IfOp op) {
135+
return op.getElseRegion().getBlocks().size() > 0;
136+
}
137+
static inline bool hasElse(mlir::AffineIfOp op) {
138+
return op.getElseRegion().getBlocks().size() > 0;
139+
}

lib/sql/Passes/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_dialect_library(MLIRSQLTransforms
2+
SQLLower.cpp
3+
4+
DEPENDS
5+
MLIRPolygeistOpsIncGen
6+
MLIRPolygeistPassIncGen
7+
8+
LINK_LIBS PUBLIC
9+
MLIRArithDialect
10+
MLIRFuncDialect
11+
MLIRFuncTransforms
12+
MLIRIR
13+
MLIRLLVMDialect
14+
MLIRMathDialect
15+
MLIRMemRefDialect
16+
MLIRPass
17+
)

lib/sql/Passes/SQLLower.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
//===- SQLLower.cpp - Lower sql ops to mlir ------ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into
10+
// a generic SQL for representation
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "PassDetails.h"
14+
#include "mlir/Analysis/CallGraph.h"
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/Async/IR/Async.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
#include "mlir/Transforms/Passes.h"
22+
#include "llvm/ADT/SetVector.h"
23+
#include "llvm/ADT/SmallPtrSet.h"
24+
#include <algorithm>
25+
#include <mutex>
26+
27+
#define DEBUG_TYPE "sql-opt"
28+
29+
using namespace mlir;
30+
using namespace mlir::arith;
31+
using namespace mlir::func;
32+
using namespace sql;
33+
34+
namespace {
35+
struct SQLLower : public SQLLowerBase<SQLLower> {
36+
void runOnOperation() override;
37+
};
38+
39+
} // end anonymous namespace
40+
41+
struct NumResultsOpLowering : public OpRewritePattern<sql::NumResultsOp> {
42+
using OpRewritePattern<sql::NumResultsOp>::OpRewritePattern;
43+
44+
LogicalResult matchAndRewrite(sql::NumResultsOp loop,
45+
PatternRewriter &rewriter) const final {
46+
auto module = loop->getParentOfType<ModuleOp>();
47+
48+
// 1) make sure the postgres_getresult function is declared
49+
auto rowsfn = dyn_cast_or_null<func::FuncOp>(symbolTable.lookupSymbolIn(
50+
module, builder.getStringAttr("PQcmdTuples")));
51+
52+
auto atoifn = dyn_cast_or_null<func::FuncOp>(
53+
symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")));
54+
55+
// 2) convert the args to valid args to postgres_getresult abi
56+
Value arg = loop.getHandle();
57+
arg = rewriter.create<arith::IndexCastOp>(loop.getLoc(),
58+
rewriter.getIntTy(64), arg);
59+
arg = rewriter.create<LLVM::IntToPtrOp>(
60+
loop.getLoc(), LLVM::LLVMPointerType::get(builder.getInt8Ty()), arg);
61+
62+
// 3) call and replace
63+
Value args[] = {arg} Value res =
64+
rewriter.create<mlir::func::CallOp>(loop.getLoc(), rowsfn, args)
65+
->getResult(0);
66+
67+
Value args2[] = {res} Value res2 =
68+
rewriter.create<mlir::func::CallOp>(loop.getLoc(), atoifn, args2)
69+
->getResult(0);
70+
71+
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
72+
loop, rewriter.getIndexType(), res2);
73+
74+
// 4) done
75+
return success();
76+
}
77+
};
78+
79+
void SQLLower::runOnOperation() {
80+
auto module = getOperation();
81+
OpBuilder builder(module.getContext());
82+
builder.setInsertionPointToStart(module.getBody());
83+
84+
if (!dyn_cast_or_null<func::FuncOp>(symbolTable.lookupSymbolIn(
85+
module, builder.getStringAttr("PQcmdTuples")))) {
86+
mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())};
87+
mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())};
88+
89+
auto fn =
90+
builder.create<func::FuncOp>(module.getLoc(), "PQcmdTuples",
91+
builder.getFunctionType(argtys, rettys));
92+
SymbolTable::setSymbolVisibility(fn, SymbolTable::Private);
93+
}
94+
if (!dyn_cast_or_null<func::FuncOp>(
95+
symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) {
96+
mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())};
97+
98+
// todo use data layout
99+
mlir::Type rettypes[] = {builder.getIntTy(sizeof(int))};
100+
101+
auto fn = builder.create<func::FuncOp>(
102+
module.getLoc(), "atoi", builder.getFunctionType(argtys, rettys));
103+
SymbolTable::setSymbolVisibility(fn, SymbolTable::Private);
104+
}
105+
106+
RewritePatternSet patterns(&getContext());
107+
patterns.insert<NumResultsOpLowering>(&getContext());
108+
109+
GreedyRewriteConfig config;
110+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
111+
config);
112+
}
113+
114+
namespace mlir {
115+
namespace polygeist {
116+
std::unique_ptr<Pass> createSQLLowerPass() {
117+
return std::make_unique<SQLLower>();
118+
}
119+
} // namespace polygeist
120+
} // namespace mlir

0 commit comments

Comments
 (0)