diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 6bd0e2d4d4b08..18bd79de2ec2a 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -20,8 +20,11 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" #include +#include using namespace mlir; @@ -269,6 +272,62 @@ struct ConvertLoad final : public OpConversionPattern { } }; +struct ConvertReinterpretCastOp final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + MemRefType srcType = cast(castOp.getSource().getType()); + + MemRefType targetMemRefType = + cast(castOp.getResult().getType()); + + auto srcInEmitC = convertMemRefType(srcType, getTypeConverter()); + auto targetInEmitC = + convertMemRefType(targetMemRefType, getTypeConverter()); + if (!srcInEmitC || !targetInEmitC) { + return rewriter.notifyMatchFailure(castOp.getLoc(), + "cannot convert memref type"); + } + Location loc = castOp.getLoc(); + + auto srcArrayValue = + cast>(adaptor.getSource()); + + emitc::ConstantOp zeroIndex = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); + + auto createPointerFromEmitcArray = + [loc, &rewriter, &zeroIndex]( + mlir::TypedValue arrayValue) -> emitc::ApplyOp { + int64_t rank = arrayValue.getType().getRank(); + llvm::SmallVector indices; + for (int i = 0; i < rank; ++i) { + indices.push_back(zeroIndex); + } + + emitc::SubscriptOp subPtr = rewriter.create( + loc, arrayValue, mlir::ValueRange(indices)); + emitc::ApplyOp ptr = rewriter.create( + loc, emitc::PointerType::get(arrayValue.getType().getElementType()), + rewriter.getStringAttr("&"), subPtr); + + return ptr; + }; + + auto srcPtr = createPointerFromEmitcArray(srcArrayValue); + + auto castCall = rewriter.create( + loc, emitc::PointerType::get(targetInEmitC), srcPtr.getResult()); + + rewriter.replaceOp(castOp, castCall); + return success(); + } +}; + struct ConvertStore final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -321,5 +380,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { patterns.add(converter, patterns.getContext()); + ConvertLoad, ConvertReinterpretCastOp, ConvertStore>( + converter, patterns.getContext()); } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 8e83e455d1a7f..6ec1c070fde83 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1756,6 +1756,20 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type, StringRef name) { + if (auto pType = dyn_cast(type)) { + if (auto aType = dyn_cast(pType.getPointee())) { + if (failed(emitType(loc, aType.getElementType()))) + return failure(); + os << " (*" << name << ")"; + for (auto dim : aType.getShape()) + os << "[" << dim << "]"; + return success(); + } + if (failed(emitType(loc, pType.getPointee()))) + return failure(); + os << " *" << name; + return success(); + } if (auto arrType = dyn_cast(type)) { if (failed(emitType(loc, arrType.getElementType()))) return failure(); @@ -1848,8 +1862,17 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) { if (auto lType = dyn_cast(type)) return emitType(loc, lType.getValueType()); if (auto pType = dyn_cast(type)) { - if (isa(pType.getPointee())) - return emitError(loc, "cannot emit pointer to array type ") << type; + // Check if the pointee is an array type. + if (auto aType = dyn_cast(pType.getPointee())) { + // Handle pointer to array: `element_type (*)[dim]`. + if (failed(emitType(loc, aType.getElementType()))) + return failure(); + os << "(*)"; + for (auto dim : aType.getShape()) + os << "[" << dim << "]"; + return success(); + } + // Handle standard pointer: `element_type*`. if (failed(emitType(loc, pType.getPointee()))) return failure(); os << "*"; diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir new file mode 100644 index 0000000000000..5d610fe1ae642 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s + +func.func @casting(%arg0: memref<999xi32>) { + %reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1] : memref<999xi32> to memref<1x1x999xi32> + return +} + +//CHECK: module { +//CHECK-NEXT: func.func @casting(%arg0: memref<999xi32>) { +//CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32> +//CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index +//CHECK-NEXT: %2 = emitc.subscript %0[%1] : (!emitc.array<999xi32>, index) -> !emitc.lvalue +//CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue) -> !emitc.ptr +//CHECK-NEXT: %4 = emitc.call_opaque "reinterpret_cast"(%3) {args = [0 : index], template_args = [!emitc.ptr>]} : (!emitc.ptr) -> !emitc.ptr> +//CHECK-NEXT: return +