Skip to content

Commit 1621eda

Browse files
wsmosesivanradanov
andauthored
Fix and add test for async lowering (#301)
* Fix and add test for async lowering * generalize test Co-authored-by: Ivan Radanov Ivanov <[email protected]>
1 parent 4370bba commit 1621eda

File tree

2 files changed

+94
-8
lines changed

2 files changed

+94
-8
lines changed

lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@
4040
using namespace mlir;
4141
using namespace polygeist;
4242

43-
mlir::Value callMalloc(mlir::OpBuilder &builder, mlir::ModuleOp module,
44-
mlir::Location loc, mlir::Value arg);
4543
mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(ModuleOp module);
4644

4745
/// Conversion pattern that transforms a subview op into:
@@ -563,9 +561,11 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern<async::ExecuteOp> {
563561
loc, rewriter.getI64Type(),
564562
rewriter.create<polygeist::TypeSizeOp>(loc, rewriter.getIndexType(),
565563
ST));
566-
mlir::Value alloc = rewriter.create<LLVM::BitcastOp>(
567-
loc, LLVM::LLVMPointerType::get(ST),
568-
callMalloc(rewriter, module, loc, arg));
564+
auto mallocFunc = LLVM::lookupOrCreateMallocFn(module, getIndexType());
565+
mlir::Value alloc =
566+
rewriter.create<LLVM::CallOp>(loc, mallocFunc, arg).getResult();
567+
alloc = rewriter.create<LLVM::BitcastOp>(
568+
loc, LLVM::LLVMPointerType::get(ST), alloc);
569569
rewriter.setInsertionPoint(execute);
570570
for (auto idx : llvm::enumerate(crossing)) {
571571

@@ -584,8 +584,14 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern<async::ExecuteOp> {
584584
vals.push_back(
585585
rewriter.create<LLVM::AddressOfOp>(execute.getLoc(), func));
586586
for (auto dep : execute.getDependencies()) {
587-
auto ctx = dep.getDefiningOp<polygeist::StreamToTokenOp>();
588-
vals.push_back(ctx.getSource());
587+
auto src = dep.getDefiningOp<polygeist::StreamToTokenOp>().getSource();
588+
if (auto MT = src.getType().dyn_cast<MemRefType>())
589+
src = rewriter.create<polygeist::Memref2PointerOp>(
590+
dep.getDefiningOp()->getLoc(),
591+
LLVM::LLVMPointerType::get(MT.getElementType(),
592+
MT.getMemorySpaceAsInt()),
593+
src);
594+
vals.push_back(src);
589595
}
590596
assert(vals.size() == 3);
591597

@@ -1394,7 +1400,7 @@ struct ConvertPolygeistToLLVMPass
13941400
*/
13951401

13961402
if (i == 1) {
1397-
target.addIllegalOp<UnrealizedConversionCastOp>();
1403+
// target.addIllegalOp<UnrealizedConversionCastOp>();
13981404
patterns.add<AsyncOpLowering>(converter);
13991405
patterns.add<StreamToTokenOpLowering>(converter);
14001406
}

test/polygeist-opt/asynclower.mlir

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// RUN: polygeist-opt --convert-polygeist-to-llvm --split-input-file %s --allow-unregistered-dialect | FileCheck %s
2+
3+
module {
4+
llvm.func @_Z3runP11CUstream_stPii(%arg0: !llvm.ptr<struct<()>>, %arg1: !llvm.ptr<i32>, %arg2: i32) {
5+
%0 = llvm.mlir.constant(0 : index) : i64
6+
%1 = llvm.mlir.constant(1 : index) : i64
7+
%2 = llvm.mlir.constant(20 : index) : i64
8+
%3 = llvm.mlir.constant(10 : index) : i64
9+
%4 = llvm.bitcast %arg0 : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
10+
%5 = llvm.bitcast %4 : !llvm.ptr<i8> to !llvm.ptr<i8>
11+
%6 = builtin.unrealized_conversion_cast %5 : !llvm.ptr<i8> to memref<?xi8>
12+
%7 = "polygeist.stream2token"(%6) : (memref<?xi8>) -> !async.token
13+
%token = async.execute [%7] {
14+
omp.parallel {
15+
omp.wsloop for (%arg3, %arg4) : i64 = (%0, %0) to (%3, %2) step (%1, %1) {
16+
llvm.call @_Z9somethingPii(%arg1, %arg2) : (!llvm.ptr<i32>, i32) -> ()
17+
omp.yield
18+
}
19+
omp.terminator
20+
}
21+
async.yield
22+
}
23+
llvm.return
24+
}
25+
llvm.func @_Z9somethingPii(!llvm.ptr<i32>, i32) attributes {sym_visibility = "private"}
26+
}
27+
28+
// CHECK-LABEL: llvm.func @_Z3runP11CUstream_stPii(
29+
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<()>>,
30+
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<i32>,
31+
// CHECK-SAME: %[[VAL_2:.*]]: i32) {
32+
// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(0 : index) : i64
33+
// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(1 : index) : i64
34+
// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(20 : index) : i64
35+
// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.constant(10 : index) : i64
36+
// CHECK-NEXT: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
37+
// CHECK-NEXT: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<i8> to !llvm.ptr<i8>
38+
// CHECK-NEXT: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !llvm.ptr<i8> to memref<?xi8>
39+
// CHECK-NEXT: %[[VAL_10:.*]] = llvm.mlir.constant(16 : i64) : i64
40+
// CHECK-NEXT: %[[VAL_11:.*]] = llvm.call @malloc(%[[VAL_10]]) : (i64) -> !llvm.ptr<i8>
41+
// CHECK-NEXT: %[[VAL_12:.*]] = llvm.bitcast %[[VAL_11]] : !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<i32>, i32)>>
42+
// CHECK-NEXT: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : i32
43+
// CHECK-NEXT: %[[VAL_14:.*]] = llvm.mlir.constant(0 : i32) : i32
44+
// CHECK-NEXT: %[[VAL_15:.*]] = llvm.getelementptr %[[VAL_12]]{{\[}}%[[VAL_13]], 0] : (!llvm.ptr<struct<(ptr<i32>, i32)>>, i32) -> !llvm.ptr<ptr<i32>>
45+
// CHECK-NEXT: llvm.store %[[VAL_1]], %[[VAL_15]] : !llvm.ptr<ptr<i32>>
46+
// CHECK-NEXT: %[[VAL_16:.*]] = llvm.mlir.constant(0 : i32) : i32
47+
// CHECK-NEXT: %[[VAL_17:.*]] = llvm.mlir.constant(1 : i32) : i32
48+
// CHECK-NEXT: %[[VAL_18:.*]] = llvm.getelementptr %[[VAL_12]]{{\[}}%[[VAL_16]], 1] : (!llvm.ptr<struct<(ptr<i32>, i32)>>, i32) -> !llvm.ptr<i32>
49+
// CHECK-NEXT: llvm.store %[[VAL_2]], %[[VAL_18]] : !llvm.ptr<i32>
50+
// CHECK-NEXT: %[[VAL_19:.*]] = llvm.bitcast %[[VAL_12]] : !llvm.ptr<struct<(ptr<i32>, i32)>> to !llvm.ptr<i8>
51+
// CHECK-NEXT: %[[VAL_20:.*]] = llvm.mlir.addressof @kernelbody.{{[0-9\.]+}} : !llvm.ptr<func<void (ptr<i8>)>>
52+
// CHECK-NEXT: %[[VAL_21:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr<i8> to !llvm.ptr<i8>
53+
// CHECK-NEXT: llvm.call @fake_cuda_dispatch(%[[VAL_19]], %[[VAL_20]], %[[VAL_21]]) : (!llvm.ptr<i8>, !llvm.ptr<func<void (ptr<i8>)>>, !llvm.ptr<i8>) -> ()
54+
// CHECK-NEXT: llvm.return
55+
56+
// CHECK-LABEL: llvm.func @kernelbody.{{[0-9\.]+}}(
57+
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) {
58+
// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
59+
// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(10 : index) : i64
60+
// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(20 : index) : i64
61+
// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(1 : index) : i64
62+
// CHECK-NEXT: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<i32>, i32)>>
63+
// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.constant(0 : i32) : i32
64+
// CHECK-NEXT: %[[VAL_7:.*]] = llvm.mlir.constant(0 : i32) : i32
65+
// CHECK-NEXT: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_5]]{{\[}}%[[VAL_6]], 0] : (!llvm.ptr<struct<(ptr<i32>, i32)>>, i32) -> !llvm.ptr<ptr<i32>>
66+
// CHECK-NEXT: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i32>>
67+
// CHECK-NEXT: %[[VAL_10:.*]] = llvm.mlir.constant(0 : i32) : i32
68+
// CHECK-NEXT: %[[VAL_11:.*]] = llvm.mlir.constant(1 : i32) : i32
69+
// CHECK-NEXT: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_5]]{{\[}}%[[VAL_10]], 1] : (!llvm.ptr<struct<(ptr<i32>, i32)>>, i32) -> !llvm.ptr<i32>
70+
// CHECK-NEXT: %[[VAL_13:.*]] = llvm.load %[[VAL_12]] : !llvm.ptr<i32>
71+
// CHECK-NEXT: llvm.call @free(%[[VAL_0]]) : (!llvm.ptr<i8>) -> ()
72+
// CHECK-NEXT: omp.parallel {
73+
// CHECK-NEXT: omp.wsloop for (%[[VAL_14:.*]], %[[VAL_15:.*]]) : i64 = (%[[VAL_1]], %[[VAL_1]]) to (%[[VAL_2]], %[[VAL_3]]) step (%[[VAL_4]], %[[VAL_4]]) {
74+
// CHECK-NEXT: llvm.call @_Z9somethingPii(%[[VAL_9]], %[[VAL_13]]) : (!llvm.ptr<i32>, i32) -> ()
75+
// CHECK-NEXT: omp.yield
76+
// CHECK-NEXT: }
77+
// CHECK-NEXT: omp.terminator
78+
// CHECK-NEXT: }
79+
// CHECK-NEXT: llvm.return
80+

0 commit comments

Comments
 (0)