diff --git a/lib/polygeist/Dialect.cpp b/lib/polygeist/Dialect.cpp index 692a3dfbc87b..242edca4a64e 100644 --- a/lib/polygeist/Dialect.cpp +++ b/lib/polygeist/Dialect.cpp @@ -8,11 +8,46 @@ #include "polygeist/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/Transforms/InliningUtils.h" #include "polygeist/Ops.h" using namespace mlir; using namespace mlir::polygeist; +//===----------------------------------------------------------------------===// +// PolygeistDialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { +/// This class defines the interface for handling inlining with polygeist +/// operations. +struct PolygeistInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// Returns true if the given region 'src' can be inlined into the region + /// 'dest' that is attached to an operation registered to the current dialect. + /// 'wouldBeCloned' is set if the region is cloned into its new location + /// rather than moved, indicating there may be other users. + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + /// Returns true if the given operation 'op', that is registered to this + /// dialect, can be inlined into the given region, false otherwise. + bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + /// Polygeist regions should be analyzed recursively. + bool shouldAnalyzeRecursively(Operation *op) const final { return true; } +}; +} // namespace //===----------------------------------------------------------------------===// // Polygeist dialect. //===----------------------------------------------------------------------===// @@ -22,6 +57,7 @@ void PolygeistDialect::initialize() { #define GET_OP_LIST #include "polygeist/PolygeistOps.cpp.inc" >(); + addInterfaces(); } #include "polygeist/PolygeistOpsDialect.cpp.inc" diff --git a/test/polygeist-opt/inline.mlir b/test/polygeist-opt/inline.mlir new file mode 100644 index 000000000000..3be1e8043f9d --- /dev/null +++ b/test/polygeist-opt/inline.mlir @@ -0,0 +1,71 @@ +// RUN: polygeist-opt --inline --allow-unregistered-dialect %s | FileCheck %s + +module { + func.func @eval_cost(%arg0: !llvm.ptr, %arg1: i1) -> memref<1x!llvm.struct<(f32, array<3 x f32>)>> { + %1 = "polygeist.pointer2memref"(%arg0) : (!llvm.ptr) -> memref<1x!llvm.struct<(f32, array<3 x f32>)>> + return %1 : memref<1x!llvm.struct<(f32, array<3 x f32>)>> + } + func.func @eval_res(%arg0: memref<1x!llvm.struct<(f32, array<3 x f32>)>>) -> !llvm.struct<(f32, array<3 x f32>)> + { + %c0_i1 = arith.constant 0 : i1 + %alloca = memref.alloca() : memref<1x!llvm.struct<(f32, array<3 x f32>)>> + %0 = "polygeist.memref2pointer"(%alloca) : (memref<1x!llvm.struct<(f32, array<3 x f32>)>>) -> !llvm.ptr + %1 = func.call @eval_cost(%0, %c0_i1) : (!llvm.ptr, i1) -> memref<1x!llvm.struct<(f32, array<3 x f32>)>> + %2 = affine.load %1[0] : memref<1x!llvm.struct<(f32, array<3 x f32>)>> + return %2 : !llvm.struct<(f32, array<3 x f32>)> + } + +// CHECK-LABEL: func.func @eval_cost( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i1) -> memref<1x!llvm.struct<(f32, array<3 x f32>)>> { +// CHECK: %[[VAL_2:.*]] = "polygeist.pointer2memref"(%[[VAL_0]]) : (!llvm.ptr) -> memref<1x!llvm.struct<(f32, array<3 x f32>)>> +// CHECK: return %[[VAL_2]] : memref<1x!llvm.struct<(f32, array<3 x f32>)>> +// CHECK: } + +// CHECK-LABEL: func.func @eval_res( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x!llvm.struct<(f32, array<3 x f32>)>>) -> !llvm.struct<(f32, array<3 x f32>)> { +// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref)>> +// CHECK: %[[VAL_2:.*]] = affine.load %[[VAL_1]][] : memref)>> +// CHECK: return %[[VAL_2]] : !llvm.struct<(f32, array<3 x f32>)> +// CHECK: } + + func.func private @use(%arg0: index, %arg1: index) -> index{ + %0 = arith.addi %arg0, %arg1 : index + return %0 : index + } + + func.func @f1(%gd : index, %bd : index) { + %mc0 = arith.constant 0 : index + %mc4 = arith.constant 4 : index + %mc1024 = arith.constant 1024 : index + %err = "polygeist.gpu_wrapper"() ({ + affine.parallel (%a1, %a2, %a3) = (0, 0, 0) to (%gd, %mc4, %bd) { + "polygeist.noop"(%a3, %mc0, %mc0) {polygeist.noop_type="gpu_kernel.thread_only"} : (index, index, index) -> () + %a1r = func.call @use(%a1,%mc4) : (index, index) -> (index) + %a2r = func.call @use(%a2,%a1r) : (index, index) -> (index) + %a3r = func.call @use(%a3,%a2r) : (index, index) -> (index) + "test.something"(%a3r) : (index) -> () + } + "polygeist.polygeist_yield"() : () -> () + }) : () -> index + return + } +// CHECK-LABEL: func.func @f1( +// CHECK-SAME: %[[VAL_0:.*]]: index, +// CHECK-SAME: %[[VAL_1:.*]]: index) { +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_4:.*]] = "polygeist.gpu_wrapper"() ({ +// CHECK: affine.parallel (%[[VAL_5:.*]], %[[VAL_6:.*]], %[[VAL_7:.*]]) = (0, 0, 0) to (symbol(%[[VAL_0]]), 4, symbol(%[[VAL_1]])) { +// CHECK: "polygeist.noop"(%[[VAL_7]], %[[VAL_2]], %[[VAL_2]]) {polygeist.noop_type = "gpu_kernel.thread_only"} : (index, index, index) -> () +// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_5]], %[[VAL_3]] : index +// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_6]], %[[VAL_8]] : index +// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_7]], %[[VAL_9]] : index +// CHECK: "test.something"(%[[VAL_10]]) : (index) -> () +// CHECK: } +// CHECK: "polygeist.polygeist_yield"() : () -> () +// CHECK: }) : () -> index +// CHECK: return +// CHECK: } + +} \ No newline at end of file diff --git a/tools/cgeist/Lib/clang-mlir.cc b/tools/cgeist/Lib/clang-mlir.cc index 058464c323ef..51ff2b47c098 100644 --- a/tools/cgeist/Lib/clang-mlir.cc +++ b/tools/cgeist/Lib/clang-mlir.cc @@ -2323,10 +2323,10 @@ ValueCategory MLIRScanner::VisitUnaryOperator(clang::UnaryOperator *U) { builder.create(loc, 1, ty.cast())); } sub.store(loc, builder, next); - return ValueCategory( - (U->getOpcode() == clang::UnaryOperator::Opcode::UO_PostDec) ? prev - : next, - /*isReference*/ false); + if (U->getOpcode() == clang::UnaryOperator::Opcode::UO_PreDec) + return sub; + else + return ValueCategory(prev, /*isReference*/ false); } case clang::UnaryOperator::Opcode::UO_Real: case clang::UnaryOperator::Opcode::UO_Imag: { diff --git a/tools/cgeist/Test/Verification/arrayconsllvm.cpp b/tools/cgeist/Test/Verification/arrayconsllvm.cpp index a5e47a7b0cfe..4a2f98698b83 100644 --- a/tools/cgeist/Test/Verification/arrayconsllvm.cpp +++ b/tools/cgeist/Test/Verification/arrayconsllvm.cpp @@ -1,4 +1,4 @@ -// RUN: cgeist %s --function=* -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s struct AIntDivider { AIntDivider() : divisor(3) {} diff --git a/tools/cgeist/Test/Verification/arrayconsmemrefinner.cpp b/tools/cgeist/Test/Verification/arrayconsmemrefinner.cpp index 52332ec0c742..bb1416652d3a 100644 --- a/tools/cgeist/Test/Verification/arrayconsmemrefinner.cpp +++ b/tools/cgeist/Test/Verification/arrayconsmemrefinner.cpp @@ -1,4 +1,4 @@ -// RUN: cgeist %s --function=* -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s struct AIntDivider { AIntDivider() : divisor(3) {} diff --git a/tools/cgeist/Test/Verification/base_cast.cpp b/tools/cgeist/Test/Verification/base_cast.cpp index 2ae3651b16be..68174e6ec831 100644 --- a/tools/cgeist/Test/Verification/base_cast.cpp +++ b/tools/cgeist/Test/Verification/base_cast.cpp @@ -1,5 +1,5 @@ -// RUN: cgeist %s --function=* -S | FileCheck %s -// RUN: cgeist %s --function=* --struct-abi=0 -memref-abi=0 -S | FileCheck %s --check-prefix CHECK-STR +// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=* --struct-abi=0 -memref-abi=0 -S | FileCheck %s --check-prefix CHECK-STR struct A { diff --git a/tools/cgeist/Test/Verification/base_nostructabi.cpp b/tools/cgeist/Test/Verification/base_nostructabi.cpp index 735693f6536d..77efbf91f83b 100644 --- a/tools/cgeist/Test/Verification/base_nostructabi.cpp +++ b/tools/cgeist/Test/Verification/base_nostructabi.cpp @@ -1,4 +1,4 @@ -// RUN: cgeist %s --function=* --struct-abi=0 -memref-abi=0 -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=* --struct-abi=0 -memref-abi=0 -S | FileCheck %s void run0(void*); void run1(void*); diff --git a/tools/cgeist/Test/Verification/base_with_virt.cpp b/tools/cgeist/Test/Verification/base_with_virt.cpp index 3d4f8f188027..0694213f0680 100644 --- a/tools/cgeist/Test/Verification/base_with_virt.cpp +++ b/tools/cgeist/Test/Verification/base_with_virt.cpp @@ -1,4 +1,4 @@ -// RUN: cgeist %s --function=* -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s class M { }; diff --git a/tools/cgeist/Test/Verification/caff.cpp b/tools/cgeist/Test/Verification/caff.cpp index 6a5dc5c426d8..7c9a5df4474f 100644 --- a/tools/cgeist/Test/Verification/caff.cpp +++ b/tools/cgeist/Test/Verification/caff.cpp @@ -1,4 +1,4 @@ -// RUN: cgeist %s --function=* -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s struct AOperandInfo { void* data; diff --git a/tools/cgeist/Test/Verification/capture.cpp b/tools/cgeist/Test/Verification/capture.cpp index 159c5a1fc87b..3007c907ae35 100644 --- a/tools/cgeist/Test/Verification/capture.cpp +++ b/tools/cgeist/Test/Verification/capture.cpp @@ -1,4 +1,4 @@ -// RUN: cgeist %s --function=* -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s extern "C" { diff --git a/tools/cgeist/Test/Verification/consabi.cpp b/tools/cgeist/Test/Verification/consabi.cpp index 3b50994b2baf..26e27cd5b152 100644 --- a/tools/cgeist/Test/Verification/consabi.cpp +++ b/tools/cgeist/Test/Verification/consabi.cpp @@ -15,42 +15,50 @@ QStream ilaunch_kernel(QStream x) { } // CHECK-LABEL: func.func @_Z14ilaunch_kernel7QStream( -// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: !llvm.struct<(struct<(f64, f64)>, i32)>) -> !llvm.struct<(struct<(f64, f64)>, i32)> -// CHECK: %[[VAL_1:[A-Za-z0-9_]*]] = memref.alloca() : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> -// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = memref.cast %[[VAL_1]] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> to memref, i32)>> -// CHECK: %[[VAL_3:[A-Za-z0-9_]*]] = memref.alloca() : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> -// CHECK: %[[VAL_4:[A-Za-z0-9_]*]] = memref.cast %[[VAL_3]] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> to memref, i32)>> -// CHECK: affine.store %[[VAL_0]], %[[VAL_3]][0] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> -// CHECK: call @_ZN7QStreamC1EOS_(%[[VAL_2]], %[[VAL_4]]) : (memref, i32)>>, memref, i32)>>) -> () -// CHECK: %[[VAL_5:[A-Za-z0-9_]*]] = affine.load %[[VAL_1]][0] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> -// CHECK: return %[[VAL_5]] : !llvm.struct<(struct<(f64, f64)>, i32)> +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(struct<(f64, f64)>, i32)>) -> !llvm.struct<(struct<(f64, f64)>, i32)> attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> +// CHECK: %[[VAL_2:.*]] = memref.alloca() : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> +// CHECK: affine.store %[[VAL_0]], %[[VAL_2]][0] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> +// CHECK: %[[VAL_3:.*]] = "polygeist.memref2pointer"(%[[VAL_1]]) : (memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>) -> !llvm.ptr +// CHECK: %[[VAL_4:.*]] = "polygeist.memref2pointer"(%[[VAL_2]]) : (memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>>) -> !llvm.ptr +// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr -> f64 +// CHECK: llvm.store %[[VAL_5]], %[[VAL_3]] : f64, !llvm.ptr +// CHECK: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_4]][1] : (!llvm.ptr) -> !llvm.ptr, f64 +// CHECK: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr -> f64 +// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_3]][1] : (!llvm.ptr) -> !llvm.ptr, f64 +// CHECK: llvm.store %[[VAL_7]], %[[VAL_8]] : f64, !llvm.ptr +// CHECK: %[[VAL_9:.*]] = llvm.getelementptr %[[VAL_4]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)> +// CHECK: %[[VAL_10:.*]] = llvm.load %[[VAL_9]] : !llvm.ptr -> i32 +// CHECK: %[[VAL_11:.*]] = llvm.getelementptr %[[VAL_3]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)> +// CHECK: llvm.store %[[VAL_10]], %[[VAL_11]] : i32, !llvm.ptr +// CHECK: %[[VAL_12:.*]] = affine.load %[[VAL_1]][0] : memref<1x!llvm.struct<(struct<(f64, f64)>, i32)>> +// CHECK: return %[[VAL_12]] : !llvm.struct<(struct<(f64, f64)>, i32)> // CHECK: } // CHECK-LABEL: func.func @_ZN7QStreamC1EOS_( -// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: memref, i32)>>, -// CHECK-SAME: %[[VAL_1:[A-Za-z0-9_]*]]: memref, i32)>>) -// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = "polygeist.memref2pointer"(%[[VAL_0]]) : (memref, i32)>>) -> !llvm.ptr -// CHECK: %[[VAL_3:[A-Za-z0-9_]*]] = "polygeist.memref2pointer"(%[[VAL_1]]) : (memref, i32)>>) -> !llvm.ptr -// CHECK: %[[VAL_4:[A-Za-z0-9_]*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> f64 +// CHECK-SAME: %[[VAL_0:.*]]: memref, i32)>>, +// CHECK-SAME: %[[VAL_1:.*]]: memref, i32)>>) attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[VAL_2:.*]] = "polygeist.memref2pointer"(%[[VAL_0]]) : (memref, i32)>>) -> !llvm.ptr +// CHECK: %[[VAL_3:.*]] = "polygeist.memref2pointer"(%[[VAL_1]]) : (memref, i32)>>) -> !llvm.ptr +// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> f64 // CHECK: llvm.store %[[VAL_4]], %[[VAL_2]] : f64, !llvm.ptr -// CHECK: %[[VAL_5:[A-Za-z0-9_]*]] = llvm.getelementptr %[[VAL_3]][1] : (!llvm.ptr) -> !llvm.ptr, f64 -// CHECK: %[[VAL_6:[A-Za-z0-9_]*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> f64 -// CHECK: %[[VAL_7:[A-Za-z0-9_]*]] = llvm.getelementptr %[[VAL_2]][1] : (!llvm.ptr) -> !llvm.ptr, f64 +// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_3]][1] : (!llvm.ptr) -> !llvm.ptr, f64 +// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> f64 +// CHECK: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_2]][1] : (!llvm.ptr) -> !llvm.ptr, f64 // CHECK: llvm.store %[[VAL_6]], %[[VAL_7]] : f64, !llvm.ptr -// CHECK: %[[VAL_8:[A-Za-z0-9_]*]] = llvm.getelementptr %[[VAL_3]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)> -// CHECK: %[[VAL_9:[A-Za-z0-9_]*]] = llvm.load %[[VAL_8]] : !llvm.ptr -> i32 -// CHECK: %[[VAL_10:[A-Za-z0-9_]*]] = llvm.getelementptr %[[VAL_2]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)> +// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_3]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)> +// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr -> i32 +// CHECK: %[[VAL_10:.*]] = llvm.getelementptr %[[VAL_2]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(struct<(f64, f64)>, i32)> // CHECK: llvm.store %[[VAL_9]], %[[VAL_10]] : i32, !llvm.ptr // CHECK: return // CHECK: } // CHECK-LABEL: func.func @_ZN1DC1EOS_( -// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: memref, -// CHECK-SAME: %[[VAL_1:[A-Za-z0-9_]*]]: memref) -// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = affine.load %[[VAL_1]][0, 0] : memref +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref) attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[VAL_2:.*]] = affine.load %[[VAL_1]][0, 0] : memref // CHECK: affine.store %[[VAL_2]], %[[VAL_0]][0, 0] : memref -// CHECK: %[[VAL_3:[A-Za-z0-9_]*]] = affine.load %[[VAL_1]][0, 1] : memref +// CHECK: %[[VAL_3:.*]] = affine.load %[[VAL_1]][0, 1] : memref // CHECK: affine.store %[[VAL_3]], %[[VAL_0]][0, 1] : memref // CHECK: return -// CHECK: } - +// CHECK: } \ No newline at end of file diff --git a/tools/cgeist/Test/Verification/cugen.cu b/tools/cgeist/Test/Verification/cugen.cu index 40e47219ef63..4c81e95d779f 100644 --- a/tools/cgeist/Test/Verification/cugen.cu +++ b/tools/cgeist/Test/Verification/cugen.cu @@ -23,10 +23,10 @@ void start(double* w) { } // CHECK: func.func @_Z5startPd(%arg0: memref) -// CHECK-NEXT: %cst = arith.constant 2.000000e+00 : f64 -// CHECK-NEXT: %c0 = arith.constant 0 : index -// CHECK-NEXT: %c20 = arith.constant 20 : index -// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-DAG: %cst = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: %c0 = arith.constant 0 : index +// CHECK-DAG: %c20 = arith.constant 20 : index +// CHECK-DAG: %c1 = arith.constant 1 : index // CHECK-NEXT: scf.parallel (%arg1) = (%c0) to (%c20) step (%c1) { // CHECK-NEXT: memref.store %cst, %arg0[%arg1] : memref // CHECK-NEXT: scf.yield diff --git a/tools/cgeist/Test/Verification/ident.cpp b/tools/cgeist/Test/Verification/ident.cpp index 54623b73f757..6eb6594eb267 100644 --- a/tools/cgeist/Test/Verification/ident.cpp +++ b/tools/cgeist/Test/Verification/ident.cpp @@ -35,54 +35,32 @@ void lt_kernel_cuda(MTensorIterator& iter) { } } - // CHECK-LABEL: func.func @lt_kernel_cuda( -// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: memref)>)>>) -// CHECK: %[[VAL_1:[A-Za-z0-9_]*]] = arith.constant 0 : i8 -// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = memref.alloca() : memref<1x1xmemref)>)>>> -// CHECK: %[[VAL_3:[A-Za-z0-9_]*]] = memref.cast %[[VAL_2]] : memref<1x1xmemref)>)>>> to memref)>)>>> -// CHECK: %[[VAL_4:[A-Za-z0-9_]*]] = call @_ZNK15MTensorIterator11input_dtypeEv(%[[VAL_0]]) : (memref)>)>>) -> i8 -// CHECK: %[[VAL_5:[A-Za-z0-9_]*]] = arith.cmpi ne, %[[VAL_4]], %[[VAL_1]] : i8 -// CHECK: scf.if %[[VAL_5]] { -// CHECK: affine.store %[[VAL_0]], %[[VAL_2]][0, 0] : memref<1x1xmemref)>)>>> -// CHECK: func.call @_ZZ14lt_kernel_cudaENK3$_0clEv(%[[VAL_3]]) : (memref)>)>>>) -> () -// CHECK: } +// CHECK-SAME: %[[VAL_0:.*]]: memref)>)>>) attributes {llvm.linkage = #llvm.linkage} { // CHECK: return // CHECK: } // CHECK-LABEL: func.func @_ZNK15MTensorIterator11input_dtypeEv( -// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: memref)>)>>) -> i8 -// CHECK: %[[VAL_1:[A-Za-z0-9_]*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = "polygeist.memref2pointer"(%[[VAL_0]]) : (memref)>)>>) -> !llvm.ptr -// CHECK: %[[VAL_3:[A-Za-z0-9_]*]] = "polygeist.pointer2memref"(%[[VAL_2]]) : (!llvm.ptr) -> memref> -// CHECK: %[[VAL_4:[A-Za-z0-9_]*]] = call @_ZNK12MSmallVectorI12MOperandInfoEixEi(%[[VAL_3]], %[[VAL_1]]) : (memref>, i32) -> memref -// CHECK: %[[VAL_5:[A-Za-z0-9_]*]] = affine.load %[[VAL_4]][0, 1] : memref -// CHECK: return %[[VAL_5]] : i8 -// CHECK: } - -// CHECK-LABEL: func.func private @_ZZ14lt_kernel_cudaENK3$_0clEv( -// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: memref)>)>>>) -// CHECK: %[[VAL_1:[A-Za-z0-9_]*]] = affine.load %[[VAL_0]][0, 0] : memref)>)>>> -// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = call @_ZNK15MTensorIterator6deviceEv(%[[VAL_1]]) : (memref)>)>>) -> i8 -// CHECK: return +// CHECK-SAME: %[[VAL_0:.*]]: memref)>)>>) -> i8 attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[VAL_1:.*]] = "polygeist.memref2pointer"(%[[VAL_0]]) : (memref)>)>>) -> !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] : !llvm.ptr -> memref +// CHECK: %[[VAL_3:.*]] = affine.load %[[VAL_2]][0, 1] : memref +// CHECK: return %[[VAL_3]] : i8 // CHECK: } // CHECK-LABEL: func.func @_ZNK12MSmallVectorI12MOperandInfoEixEi( -// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: memref>, -// CHECK-SAME: %[[VAL_1:[A-Za-z0-9_]*]]: i32) -> memref -// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = affine.load %[[VAL_0]][0, 0] : memref> -// CHECK: %[[VAL_3:[A-Za-z0-9_]*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_4:[A-Za-z0-9_]*]] = "polygeist.subindex"(%[[VAL_2]], %[[VAL_3]]) : (memref, index) -> memref +// CHECK-SAME: %[[VAL_0:.*]]: memref>, +// CHECK-SAME: %[[VAL_1:.*]]: i32) -> memref attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[VAL_2:.*]] = affine.load %[[VAL_0]][0, 0] : memref> +// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_1]] : i32 to index +// CHECK: %[[VAL_4:.*]] = "polygeist.subindex"(%[[VAL_2]], %[[VAL_3]]) : (memref, index) -> memref // CHECK: return %[[VAL_4]] : memref // CHECK: } // CHECK-LABEL: func.func @_ZNK15MTensorIterator6deviceEv( -// CHECK-SAME: %[[VAL_0:[A-Za-z0-9_]*]]: memref)>)>>) -> i8 -// CHECK: %[[VAL_1:[A-Za-z0-9_]*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_2:[A-Za-z0-9_]*]] = "polygeist.memref2pointer"(%[[VAL_0]]) : (memref)>)>>) -> !llvm.ptr -// CHECK: %[[VAL_3:[A-Za-z0-9_]*]] = "polygeist.pointer2memref"(%[[VAL_2]]) : (!llvm.ptr) -> memref> -// CHECK: %[[VAL_4:[A-Za-z0-9_]*]] = call @_ZNK12MSmallVectorI12MOperandInfoEixEi(%[[VAL_3]], %[[VAL_1]]) : (memref>, i32) -> memref -// CHECK: %[[VAL_5:[A-Za-z0-9_]*]] = affine.load %[[VAL_4]][0, 0] : memref -// CHECK: return %[[VAL_5]] : i8 -// CHECK: } - +// CHECK-SAME: %[[VAL_0:.*]]: memref)>)>>) -> i8 attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[VAL_1:.*]] = "polygeist.memref2pointer"(%[[VAL_0]]) : (memref)>)>>) -> !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] : !llvm.ptr -> memref +// CHECK: %[[VAL_3:.*]] = affine.load %[[VAL_2]][0, 0] : memref +// CHECK: return %[[VAL_3]] : i8 +// CHECK: } \ No newline at end of file diff --git a/tools/cgeist/Test/Verification/indirect.c b/tools/cgeist/Test/Verification/indirect.c index eadcd2f8a76b..0c47b22aebde 100644 --- a/tools/cgeist/Test/Verification/indirect.c +++ b/tools/cgeist/Test/Verification/indirect.c @@ -1,5 +1,5 @@ -// RUN: cgeist %s --function=main -S | FileCheck %s -// RUN: cgeist %s --function=main -S --emit-llvm | FileCheck %s --check-prefix=LLCHECK +// RUN: cgeist %s --no-inline --function=main -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=main -S --emit-llvm | FileCheck %s --check-prefix=LLCHECK int square(int x) { return x*x; diff --git a/tools/cgeist/Test/Verification/simpcomplex.cpp b/tools/cgeist/Test/Verification/simpcomplex.cpp index ec2af11438f8..58ea1c459bad 100644 --- a/tools/cgeist/Test/Verification/simpcomplex.cpp +++ b/tools/cgeist/Test/Verification/simpcomplex.cpp @@ -1,4 +1,4 @@ -// RUN: cgeist %s --struct-abi=0 --function='*' -S | FileCheck %s --check-prefix=STRUCT +// RUN: cgeist %s --no-inline --struct-abi=0 --function='*' -S | FileCheck %s --check-prefix=STRUCT // COM: we dont support this yet: cgeist %s --function='*' -S | FileCheck %s void foo() { diff --git a/tools/cgeist/Test/Verification/unioncopy.cpp b/tools/cgeist/Test/Verification/unioncopy.cpp index cfebd9c8ff07..6a957fb5a5ae 100644 --- a/tools/cgeist/Test/Verification/unioncopy.cpp +++ b/tools/cgeist/Test/Verification/unioncopy.cpp @@ -1,4 +1,4 @@ -// RUN: cgeist %s --function=* -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s union S { double d; diff --git a/tools/cgeist/Test/Verification/virt.cpp b/tools/cgeist/Test/Verification/virt.cpp index 04e61d06d83b..b79b397b897c 100644 --- a/tools/cgeist/Test/Verification/virt.cpp +++ b/tools/cgeist/Test/Verification/virt.cpp @@ -1,4 +1,4 @@ -// RUN: cgeist %s --function=* -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s extern void print(char*); diff --git a/tools/cgeist/Test/Verification/virt2.cpp b/tools/cgeist/Test/Verification/virt2.cpp index ff919d4b1751..abb23a31d204 100644 --- a/tools/cgeist/Test/Verification/virt2.cpp +++ b/tools/cgeist/Test/Verification/virt2.cpp @@ -1,4 +1,4 @@ -// RUN: cgeist %s --function=* -S | FileCheck %s +// RUN: cgeist %s --no-inline --function=* -S | FileCheck %s extern void print(char*); diff --git a/tools/cgeist/driver.cc b/tools/cgeist/driver.cc index 2eee90d1a54c..886e11a78318 100644 --- a/tools/cgeist/driver.cc +++ b/tools/cgeist/driver.cc @@ -136,6 +136,9 @@ static cl::opt PrintDebugInfo("print-debug-info", cl::init(false), static cl::opt EmitAssembly("S", cl::init(false), cl::desc("Emit Assembly")); +static cl::opt NoInline("no-inline", cl::init(false), + cl::desc("Prevent inlining")); + static cl::opt Opt0("O0", cl::init(false), cl::desc("Opt level 0")); static cl::opt Opt1("O1", cl::init(false), cl::desc("Opt level 1")); static cl::opt Opt2("O2", cl::init(false), cl::desc("Opt level 2")); @@ -702,7 +705,8 @@ int main(int argc, char **argv) { optPM.addPass(mlir::createLowerAffinePass()); optPM.addPass(mlir::polygeist::createPolygeistCanonicalizePass( canonicalizerConfig, {}, {})); - pm.addPass(mlir::createInlinerPass()); + if (!NoInline) + pm.addPass(mlir::createInlinerPass()); mlir::OpPassManager &optPM2 = pm.nest(); optPM2.addPass(mlir::polygeist::createPolygeistCanonicalizePass( canonicalizerConfig, {}, {})); @@ -747,7 +751,8 @@ int main(int argc, char **argv) { noptPM.addPass(polygeist::createPolygeistMem2RegPass()); noptPM.addPass(mlir::polygeist::createPolygeistCanonicalizePass( canonicalizerConfig, {}, {})); - pm.addPass(mlir::createInlinerPass()); + if (!NoInline) + pm.addPass(mlir::createInlinerPass()); mlir::OpPassManager &noptPM2 = pm.nest(); noptPM2.addPass(mlir::polygeist::createPolygeistCanonicalizePass( canonicalizerConfig, {}, {})); diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index 95fe1b1fc4a4..fb53326e7cf6 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -70,6 +70,7 @@ int main(int argc, char **argv) { mlir::registerConvertAffineToStandardPass(); mlir::registerSCCPPass(); mlir::registerInlinerPass(); + mlir::registerSROAPass(); mlir::registerCanonicalizerPass(); mlir::registerSymbolDCEPass(); mlir::registerLoopInvariantCodeMotionPass();