diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h index b595b6a308bea..5abfb3d7e72dd 100644 --- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h +++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h @@ -10,8 +10,11 @@ constexpr const char *alignedAllocFunctionName = "aligned_alloc"; constexpr const char *mallocFunctionName = "malloc"; +constexpr const char *memcpyFunctionName = "memcpy"; constexpr const char *cppStandardLibraryHeader = "cstdlib"; constexpr const char *cStandardLibraryHeader = "stdlib.h"; +constexpr const char *cppStringLibraryHeader = "cstring"; +constexpr const char *cStringLibraryHeader = "string.h"; namespace mlir { class DialectRegistry; diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 6bd0e2d4d4b08..a1f38c95935ad 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -17,11 +17,13 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" #include +#include using namespace mlir; @@ -97,6 +99,48 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, + OpBuilder &builder) { + assert(isMemRefTypeLegalForEmitC(memrefType) && + "incompatible memref type for EmitC conversion"); + emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create( + builder, loc, emitc::SizeTType::get(builder.getContext()), + builder.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(builder.getContext(), + {TypeAttr::get(memrefType.getElementType())})); + + IndexType indexType = builder.getIndexType(); + int64_t numElements = std::accumulate(memrefType.getShape().begin(), + memrefType.getShape().end(), int64_t{1}, + std::multiplies()); + emitc::ConstantOp numElementsValue = emitc::ConstantOp::create( + builder, loc, indexType, builder.getIndexAttr(numElements)); + + Type sizeTType = emitc::SizeTType::get(builder.getContext()); + emitc::MulOp totalSizeBytes = emitc::MulOp::create( + builder, loc, sizeTType, elementSize.getResult(0), numElementsValue); + + return totalSizeBytes.getResult(); +} + +static emitc::ApplyOp +createPointerFromEmitcArray(Location loc, OpBuilder &builder, + TypedValue arrayValue) { + + emitc::ConstantOp zeroIndex = emitc::ConstantOp::create( + builder, loc, builder.getIndexType(), builder.getIndexAttr(0)); + + emitc::ArrayType arrayType = arrayValue.getType(); + llvm::SmallVector indices(arrayType.getRank(), zeroIndex); + emitc::SubscriptOp subPtr = + emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices)); + emitc::ApplyOp ptr = emitc::ApplyOp::create( + builder, loc, emitc::PointerType::get(arrayType.getElementType()), + builder.getStringAttr("&"), subPtr); + + return ptr; +} + struct ConvertAlloc final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -159,6 +203,47 @@ struct ConvertAlloc final : public OpConversionPattern { } }; +struct ConvertCopy final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = copyOp.getLoc(); + MemRefType srcMemrefType = cast(copyOp.getSource().getType()); + MemRefType targetMemrefType = + cast(copyOp.getTarget().getType()); + + if (!isMemRefTypeLegalForEmitC(srcMemrefType)) + return rewriter.notifyMatchFailure( + loc, "incompatible source memref type for EmitC conversion"); + + if (!isMemRefTypeLegalForEmitC(targetMemrefType)) + return rewriter.notifyMatchFailure( + loc, "incompatible target memref type for EmitC conversion"); + + auto srcArrayValue = + cast>(operands.getSource()); + emitc::ApplyOp srcPtr = + createPointerFromEmitcArray(loc, rewriter, srcArrayValue); + + auto targetArrayValue = + cast>(operands.getTarget()); + emitc::ApplyOp targetPtr = + createPointerFromEmitcArray(loc, rewriter, targetArrayValue); + + emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create( + rewriter, loc, TypeRange{}, "memcpy", + ValueRange{ + targetPtr.getResult(), srcPtr.getResult(), + calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)}); + + rewriter.replaceOp(copyOp, memCpyCall.getResults()); + + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -320,6 +405,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add(converter, patterns.getContext()); + patterns.add( + converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index e78dd76d6e256..a51890248271f 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -18,6 +18,8 @@ #include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringRef.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMEMREFTOEMITC @@ -27,6 +29,15 @@ namespace mlir { using namespace mlir; namespace { + +emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module, + StringRef headerName) { + StringAttr includeAttr = builder.getStringAttr(headerName); + return builder.create( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); +} + struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase { using Base::Base; @@ -55,34 +66,29 @@ struct ConvertMemRefToEmitCPass return signalPassFailure(); mlir::ModuleOp module = getOperation(); + llvm::SmallSet existingHeaders; + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + module.walk([&](mlir::emitc::IncludeOp includeOp) { + if (includeOp.getIsStandardInclude()) + existingHeaders.insert(includeOp.getInclude()); + }); + module.walk([&](mlir::emitc::CallOpaqueOp callOp) { - if (callOp.getCallee() != alignedAllocFunctionName && - callOp.getCallee() != mallocFunctionName) { + StringRef expectedHeader; + if (callOp.getCallee() == alignedAllocFunctionName || + callOp.getCallee() == mallocFunctionName) + expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader; + else if (callOp.getCallee() == memcpyFunctionName) + expectedHeader = + options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader; + else return mlir::WalkResult::advance(); + if (!existingHeaders.contains(expectedHeader)) { + addStandardHeader(builder, module, expectedHeader); + existingHeaders.insert(expectedHeader); } - - for (auto &op : *module.getBody()) { - emitc::IncludeOp includeOp = llvm::dyn_cast(op); - if (!includeOp) { - continue; - } - if (includeOp.getIsStandardInclude() && - ((options.lowerToCpp && - includeOp.getInclude() == cppStandardLibraryHeader) || - (!options.lowerToCpp && - includeOp.getInclude() == cStandardLibraryHeader))) { - return mlir::WalkResult::interrupt(); - } - } - - mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); - StringAttr includeAttr = - builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader - : cStandardLibraryHeader); - builder.create( - module.getLoc(), includeAttr, - /*is_standard_include=*/builder.getUnitAttr()); - return mlir::WalkResult::interrupt(); + return mlir::WalkResult::advance(); }); } }; diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir new file mode 100644 index 0000000000000..c1627a0d4d023 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir @@ -0,0 +1,50 @@ +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP + +func.func @alloc_copy(%arg0: memref<999xi32>) { + %alloc = memref.alloc() : memref<999xi32> + memref.copy %arg0, %alloc : memref<999xi32> to memref<999xi32> + %alloc_1 = memref.alloc() : memref<999xi32> + memref.copy %arg0, %alloc_1 : memref<999xi32> to memref<999xi32> + return +} + +// CHECK: module { +// NOCPP: emitc.include <"stdlib.h"> +// NOCPP-NEXT: emitc.include <"string.h"> + +// CPP: emitc.include <"cstdlib"> +// CPP-NEXT: emitc.include <"cstring"> + +// CHECK-LABEL: alloc_copy +// CHECK-SAME: %[[arg0:.*]]: memref<999xi32> +// CHECK-NEXT: builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32> +// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK-NEXT: emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t +// CHECK-NEXT: emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr> +// CHECK-NEXT: emitc.cast %4 : !emitc.ptr> to !emitc.ptr +// CHECK-NEXT: builtin.unrealized_conversion_cast %5 : !emitc.ptr to !emitc.array<999xi32> +// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK-NEXT: emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue +// CHECK-NEXT: emitc.apply "&"(%8) : (!emitc.lvalue) -> !emitc.ptr +// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK-NEXT: emitc.mul %12, %13 : (!emitc.size_t, index) -> !emitc.size_t +// CHECK-NEXT: emitc.call_opaque "memcpy"(%11, %9, %14) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () +// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK-NEXT: emitc.mul %15, %16 : (!emitc.size_t, index) -> !emitc.size_t +// CHECK-NEXT: emitc.call_opaque "malloc"(%17) : (!emitc.size_t) -> !emitc.ptr> +// CHECK-NEXT: emitc.cast %18 : !emitc.ptr> to !emitc.ptr +// CHECK-NEXT: builtin.unrealized_conversion_cast %19 : !emitc.ptr to !emitc.array<999xi32> +// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK-NEXT: emitc.subscript %0[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue +// CHECK-NEXT: emitc.apply "&"(%22) : (!emitc.lvalue) -> !emitc.ptr +// CHECK-NEXT: emitc.subscript %20[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue +// CHECK-NEXT: emitc.apply "&"(%24) : (!emitc.lvalue) -> !emitc.ptr +// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK-NEXT: emitc.mul %26, %27 : (!emitc.size_t, index) -> !emitc.size_t +// CHECK-NEXT: emitc.call_opaque "memcpy"(%25, %23, %28) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () +// CHECK-NEXT: return diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir new file mode 100644 index 0000000000000..d151d1bd53458 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP + +func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) { + memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32> + return +} + +// CHECK: module { +// NOCPP: emitc.include <"string.h"> +// CPP: emitc.include <"cstring"> + +// CHECK-LABEL: copying +// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32> +// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue) -> !emitc.ptr +// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue) -> !emitc.ptr +// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index +// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t +// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT:} + diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir index fda01974d3fc8..b6eccfc8f0050 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir @@ -1,13 +1,5 @@ // RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics -func.func @memref_op(%arg0 : memref<2x4xf32>) { - // expected-error@+1 {{failed to legalize operation 'memref.copy'}} - memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32> - return -} - -// ----- - func.func @alloca_with_dynamic_shape() { %0 = index.constant 1 // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}