Skip to content

Commit d8b6e4f

Browse files
authored
Fix atomic load (#312)
* Fix atomic load
* Fix addrspace conversion
* Parallel lower fixup
1 parent c235e45 commit d8b6e4f

File tree

8 files changed

+425
-13
lines changed

8 files changed

+425
-13
lines changed

include/polygeist/PolygeistOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def GetFuncOp : Polygeist_Op<"get_func",
118118
let arguments = (ins FlatSymbolRefAttr:$name);
119119
let results = (outs LLVM_AnyPointer : $result);
120120
let assemblyFormat = "$name `:` type($result) attr-dict";
121+
let hasCanonicalizer = 1;
121122
}
122123

123124
def TrivialUseOp : Polygeist_Op<"trivialuse"> {

lib/polygeist/Ops.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5432,6 +5432,96 @@ void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results,
54325432
// GetFuncOp
54335433
//===----------------------------------------------------------------------===//
54345434

5435+
/// Rewrite an *indirect* llvm.call whose callee value traces back to a
/// polygeist.get_func into a direct func.call of the referenced function.
///
/// \param op       The indirect llvm.call (no callee symbol; callee is
///                 operand 0).
/// \param rewriter Builder used to create the replacement ops.
/// \param vals     Out-parameter: the values that should replace op's
///                 results.
/// \return success() iff the call was rewritten; on failure nothing is
///         emitted and \p vals is untouched.
LogicalResult fixupGetFunc(LLVM::CallOp op, OpBuilder &rewriter,
                           SmallVectorImpl<Value> &vals) {
  // Only handle indirect calls; direct calls already name their callee.
  if (op.getCallee())
    return failure();

  Value pval = op.getOperand(0);

  // Function type as seen at the call site (through the callee pointer).
  auto FT = pval.getType()
                .cast<LLVM::LLVMPointerType>()
                .getElementType()
                .cast<LLVM::LLVMFunctionType>();
  // Variadic signatures cannot be forwarded 1:1 to func.call.
  if (FT.isVarArg())
    return failure();

  // Peel off bitcasts and polygeist pointer<->memref conversions to find
  // the value that actually produced the callee.
  while (true) {
    if (auto bc = pval.getDefiningOp<LLVM::BitcastOp>())
      pval = bc.getOperand();
    else if (auto mt = pval.getDefiningOp<Memref2PointerOp>())
      pval = mt.getOperand();
    else if (auto mt = pval.getDefiningOp<Pointer2MemrefOp>())
      pval = mt.getOperand();
    else
      break;
  }

  // Function type as declared by the producer of the callee value; it may
  // disagree with FT because of the casts peeled above.
  LLVM::LLVMFunctionType FT2;
  if (auto MT = pval.getType().dyn_cast<MemRefType>())
    FT2 = MT.getElementType().cast<LLVM::LLVMFunctionType>();
  else
    FT2 = pval.getType()
              .cast<LLVM::LLVMPointerType>()
              .getElementType()
              .cast<LLVM::LLVMFunctionType>();

  // Arities must agree for a per-argument rewrite below.
  if (FT2.getParams().size() != FT.getParams().size())
    return failure();

  // Only calls whose callee originates from polygeist.get_func are handled.
  auto gfn = pval.getDefiningOp<GetFuncOp>();
  if (!gfn)
    return failure();
  // Drop the callee operand; the remainder are the actual call arguments.
  SmallVector<Value> args(op.getOperands());
  args.erase(args.begin());
  // Coerce each argument to the declared parameter type. Only the
  // LLVM-pointer -> memref mismatch is supported (via pointer2memref).
  for (size_t i = 0; i < args.size(); i++) {
    if (FT2.getParams()[i] != args[i].getType()) {
      if (!FT2.getParams()[i].isa<MemRefType>() ||
          !args[i].getType().isa<LLVM::LLVMPointerType>())
        return failure();
      args[i] = rewriter.create<polygeist::Pointer2MemrefOp>(
          op.getLoc(), FT2.getParams()[i], args[i]);
    }
  }

  // NOTE(review): for calls with a result this accepts only the
  // LLVM-pointer-result / memref-declared-return combination and rejects
  // everything else (including exact type matches) — presumably because the
  // exact-match case never occurs after the casts above; confirm upstream.
  if (op.getResultTypes().size() &&
      (!op.getResultTypes()[0].isa<LLVM::LLVMPointerType>() ||
       !FT2.getReturnType().isa<MemRefType>()))
    return failure();

  // Emit the direct call, then cast any result back to the pointer type the
  // original call site expects.
  auto res = rewriter
                 .create<func::CallOp>(op.getLoc(), gfn.getNameAttr(),
                                       op.getResultTypes(), args)
                 .getResults();
  for (Value r : res) {
    if (r.getType() != FT.getReturnType())
      r = rewriter.create<polygeist::Memref2PointerOp>(op.getLoc(),
                                                       FT.getReturnType(), r);
    vals.push_back(r);
  }
  return success();
}
5504+
5505+
/// Canonicalization pattern that turns indirect llvm.call ops whose callee
/// value originates from a polygeist.get_func into direct func.call ops,
/// delegating all the matching/coercion work to fixupGetFunc.
struct GetFuncFix final : public OpRewritePattern<LLVM::CallOp> {
  using OpRewritePattern<LLVM::CallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LLVM::CallOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> replacements;
    if (failed(fixupGetFunc(op, rewriter, replacements)))
      return failure();
    // fixupGetFunc already built the direct call; swap in its results.
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
5519+
5520+
/// Registers the canonicalization patterns associated with GetFuncOp.
void GetFuncOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Fold indirect calls through get_func into direct calls.
  results.add<GetFuncFix>(context);
}
5524+
54355525
LogicalResult GetFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
54365526
// TODO: Verify that the result type is same as the type of the referenced
54375527
// func.func op.

lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,14 @@ struct Memref2PointerOpLowering
152152
ConversionPatternRewriter &rewriter) const override {
153153
auto loc = op.getLoc();
154154

155+
auto LPT = op.getType().cast<LLVM::LLVMPointerType>();
156+
auto space0 = op.getSource().getType().getMemorySpaceAsInt();
155157
if (transformed.getSource().getType().isa<LLVM::LLVMPointerType>()) {
156-
auto ptr = rewriter.create<LLVM::BitcastOp>(loc, op.getType(),
157-
transformed.getSource());
158+
mlir::Value ptr = rewriter.create<LLVM::BitcastOp>(
159+
loc, LLVM::LLVMPointerType::get(LPT.getElementType(), space0),
160+
transformed.getSource());
161+
if (space0 != LPT.getAddressSpace())
162+
ptr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, LPT, ptr);
158163
rewriter.replaceOp(op, {ptr});
159164
return success();
160165
}
@@ -169,7 +174,10 @@ struct Memref2PointerOpLowering
169174
Value ptr = targetMemRef.alignedPtr(rewriter, loc);
170175
Value idxs[] = {baseOffset};
171176
ptr = rewriter.create<LLVM::GEPOp>(loc, ptr.getType(), ptr, idxs);
172-
ptr = rewriter.create<LLVM::BitcastOp>(loc, op.getType(), ptr);
177+
ptr = rewriter.create<LLVM::BitcastOp>(
178+
loc, LLVM::LLVMPointerType::get(LPT.getElementType(), space0), ptr);
179+
if (space0 != LPT.getAddressSpace())
180+
ptr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, LPT, ptr);
173181

174182
rewriter.replaceOp(op, {ptr});
175183
return success();
@@ -997,6 +1005,25 @@ struct CLoadOpLowering : public CLoadStoreOpLowering<memref::LoadOp> {
9971005
}
9981006
};
9991007

1008+
/// Lowers memref.atomic_rmw to llvm.atomicrmw with acq_rel ordering, reusing
/// the C-style address computation from CLoadStoreOpLowering. Kinds without a
/// direct LLVM atomicrmw counterpart are left for other patterns.
struct CAtomicRMWOpLowering : public CLoadStoreOpLowering<memref::AtomicRMWOp> {
  using CLoadStoreOpLowering<memref::AtomicRMWOp>::CLoadStoreOpLowering;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Map the memref atomic kind to an LLVM atomicrmw bin-op, if one exists.
    auto binOpKind = matchSimpleAtomicOp(atomicOp);
    if (!binOpKind)
      return failure();
    // Compute the element address under the C-style lowering convention.
    auto elementPtr = getAddress(atomicOp, adaptor, rewriter);
    if (!elementPtr)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, atomicOp.getType(), *binOpKind, elementPtr,
        adaptor.getValue(), LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};
1026+
10001027
/// Pattern for lowering a memory store.
10011028
struct CStoreOpLowering : public CLoadStoreOpLowering<memref::StoreOp> {
10021029
public:
@@ -1284,7 +1311,8 @@ populateCStyleMemRefLoweringPatterns(RewritePatternSet &patterns,
12841311
LLVMTypeConverter &typeConverter) {
12851312
patterns.add<CAllocaOpLowering, CAllocOpLowering, CDeallocOpLowering,
12861313
GetGlobalOpLowering, GlobalOpLowering, CLoadOpLowering,
1287-
CStoreOpLowering, AllocaScopeOpLowering>(typeConverter);
1314+
CStoreOpLowering, AllocaScopeOpLowering, CAtomicRMWOpLowering>(
1315+
typeConverter);
12881316
}
12891317

12901318
/// Appends the patterns lowering operations from the Func dialect to the LLVM

lib/polygeist/Passes/ParallelLower.cpp

Lines changed: 131 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Transforms/Passes.h"
2828
#include "polygeist/Ops.h"
2929
#include "polygeist/Passes/Passes.h"
30+
#include "llvm/ADT/SetVector.h"
3031
#include "llvm/ADT/SmallPtrSet.h"
3132
#include <algorithm>
3233
#include <mutex>
@@ -198,18 +199,23 @@ mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(ModuleOp module) {
198199
lnk);
199200
}
200201

202+
LogicalResult fixupGetFunc(LLVM::CallOp, OpBuilder &rewriter,
203+
SmallVectorImpl<Value> &);
204+
201205
void ParallelLower::runOnOperation() {
202206
// The inliner should only be run on operations that define a symbol table,
203207
// as the callgraph will need to resolve references.
204208

205209
SymbolTableCollection symbolTable;
206210
symbolTable.getSymbolTable(getOperation());
211+
SymbolUserMap symbolUserMap(symbolTable, getOperation());
207212

208213
getOperation()->walk([&](CallOp bidx) {
209214
if (bidx.getCallee() == "cudaThreadSynchronize")
210215
bidx.erase();
211216
});
212217

218+
std::function<void(LLVM::CallOp)> LLVMcallInliner;
213219
std::function<void(CallOp)> callInliner = [&](CallOp caller) {
214220
// Build the inliner interface.
215221
AlwaysInlinerInterface interface(&getContext());
@@ -230,10 +236,72 @@ void ParallelLower::runOnOperation() {
230236
return;
231237
if (targetRegion->empty())
232238
return;
233-
SmallVector<CallOp> ops;
234-
callableOp.walk([&](CallOp caller) { ops.push_back(caller); });
235-
for (auto op : ops)
236-
callInliner(op);
239+
{
240+
SmallVector<CallOp> ops;
241+
callableOp.walk([&](CallOp caller) { ops.push_back(caller); });
242+
for (auto op : ops)
243+
callInliner(op);
244+
}
245+
{
246+
SmallVector<LLVM::CallOp> ops;
247+
callableOp.walk([&](LLVM::CallOp caller) { ops.push_back(caller); });
248+
for (auto op : ops)
249+
LLVMcallInliner(op);
250+
}
251+
OpBuilder b(caller);
252+
auto allocScope = b.create<memref::AllocaScopeOp>(caller.getLoc(),
253+
caller.getResultTypes());
254+
allocScope.getRegion().push_back(new Block());
255+
b.setInsertionPointToStart(&allocScope.getRegion().front());
256+
auto exOp = b.create<scf::ExecuteRegionOp>(caller.getLoc(),
257+
caller.getResultTypes());
258+
Block *blk = new Block();
259+
exOp.getRegion().push_back(blk);
260+
caller->moveBefore(blk, blk->begin());
261+
caller.replaceAllUsesWith(allocScope.getResults());
262+
b.setInsertionPointToEnd(blk);
263+
b.create<scf::YieldOp>(caller.getLoc(), caller.getResults());
264+
if (inlineCall(interface, caller, callableOp, targetRegion,
265+
/*shouldCloneInlinedRegion=*/true)
266+
.succeeded()) {
267+
caller.erase();
268+
}
269+
b.setInsertionPointToEnd(&allocScope.getRegion().front());
270+
b.create<memref::AllocaScopeReturnOp>(allocScope.getLoc(),
271+
exOp.getResults());
272+
};
273+
LLVMcallInliner = [&](LLVM::CallOp caller) {
274+
// Build the inliner interface.
275+
AlwaysInlinerInterface interface(&getContext());
276+
277+
auto callable = caller.getCallableForCallee();
278+
CallableOpInterface callableOp;
279+
if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
280+
if (!symRef.isa<FlatSymbolRefAttr>())
281+
return;
282+
auto *symbolOp =
283+
symbolTable.lookupNearestSymbolFrom(getOperation(), symRef);
284+
callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
285+
} else {
286+
return;
287+
}
288+
Region *targetRegion = callableOp.getCallableRegion();
289+
if (!targetRegion)
290+
return;
291+
if (targetRegion->empty())
292+
return;
293+
{
294+
SmallVector<CallOp> ops;
295+
callableOp.walk([&](CallOp caller) { ops.push_back(caller); });
296+
for (auto op : ops)
297+
callInliner(op);
298+
}
299+
{
300+
SmallVector<LLVM::CallOp> ops;
301+
callableOp.walk([&](LLVM::CallOp caller) { ops.push_back(caller); });
302+
for (auto op : ops)
303+
LLVMcallInliner(op);
304+
}
237305
OpBuilder b(caller);
238306
auto allocScope = b.create<memref::AllocaScopeOp>(caller.getLoc(),
239307
caller.getResultTypes());
@@ -256,6 +324,7 @@ void ParallelLower::runOnOperation() {
256324
b.create<memref::AllocaScopeReturnOp>(allocScope.getLoc(),
257325
exOp.getResults());
258326
};
327+
259328
{
260329
SmallVector<CallOp> dimsToInline;
261330
getOperation()->walk([&](CallOp bidx) {
@@ -268,15 +337,68 @@ void ParallelLower::runOnOperation() {
268337
}
269338

270339
// Only supports single block functions at the moment.
340+
341+
SmallVector<std::pair<Operation *, size_t>> outlineOps;
342+
getOperation().walk([&](gpu::LaunchOp launchOp) {
343+
launchOp.walk([&](LLVM::CallOp caller) {
344+
if (!caller.getCallee()) {
345+
outlineOps.push_back(std::make_pair(caller, (size_t)0));
346+
}
347+
});
348+
});
349+
SetVector<FunctionOpInterface> toinl;
350+
while (outlineOps.size()) {
351+
auto opv = outlineOps.back();
352+
auto op = std::get<0>(opv);
353+
auto idx = std::get<1>(opv);
354+
outlineOps.pop_back();
355+
if (Value fn = op->getOperand(idx)) {
356+
if (auto fn2 = fn.getDefiningOp<polygeist::Memref2PointerOp>())
357+
fn = fn2.getOperand();
358+
if (auto ba = fn.dyn_cast<BlockArgument>()) {
359+
if (auto F =
360+
dyn_cast<FunctionOpInterface>(ba.getOwner()->getParentOp())) {
361+
if (toinl.count(F))
362+
continue;
363+
toinl.insert(F);
364+
for (Operation *m : symbolUserMap.getUsers(F)) {
365+
outlineOps.push_back(std::make_pair(m, (size_t)ba.getArgNumber()));
366+
}
367+
}
368+
}
369+
}
370+
}
371+
for (auto F : toinl) {
372+
for (Operation *m : symbolUserMap.getUsers(F)) {
373+
callInliner(cast<CallOp>(m));
374+
}
375+
}
376+
getOperation().walk([&](LLVM::CallOp caller) {
377+
OpBuilder builder(caller);
378+
SmallVector<Value> vals;
379+
if (fixupGetFunc(caller, builder, vals).failed())
380+
return;
381+
if (vals.size())
382+
caller.getResult().replaceAllUsesWith(vals[0]);
383+
caller.erase();
384+
});
385+
271386
SmallVector<gpu::LaunchOp> toHandle;
272387
getOperation().walk(
273388
[&](gpu::LaunchOp launchOp) { toHandle.push_back(launchOp); });
274-
275389
for (gpu::LaunchOp launchOp : toHandle) {
276-
SmallVector<CallOp> ops;
277-
launchOp.walk([&](CallOp caller) { ops.push_back(caller); });
278-
for (auto op : ops)
279-
callInliner(op);
390+
{
391+
SmallVector<CallOp> ops;
392+
launchOp.walk([&](CallOp caller) { ops.push_back(caller); });
393+
for (auto op : ops)
394+
callInliner(op);
395+
}
396+
{
397+
SmallVector<LLVM::CallOp> lops;
398+
launchOp.walk([&](LLVM::CallOp caller) { lops.push_back(caller); });
399+
for (auto op : lops)
400+
LLVMcallInliner(op);
401+
}
280402

281403
mlir::IRRewriter builder(launchOp.getContext());
282404
auto loc = launchOp.getLoc();

0 commit comments

Comments
 (0)