From c6ea99574fde268a83f20642f507aafe1ac6edf6 Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Wed, 6 Aug 2025 20:18:38 +0000 Subject: [PATCH 1/3] needs improvment --- .../MemRefToEmitC/MemRefToEmitC.cpp | 84 ++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 6bd0e2d4d4b08..60cccb0ece8f2 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -21,7 +21,9 @@ #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" #include +#include using namespace mlir; @@ -269,6 +271,85 @@ struct ConvertLoad final : public OpConversionPattern { } }; +struct ConvertReinterpretCastOp final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + MemRefType srcType = cast(castOp.getSource().getType()); + + MemRefType targetMemRefType = + cast(castOp.getResult().getType()); + + auto srcInEmitC = convertMemRefType(srcType, getTypeConverter()); + auto targetInEmitC = + convertMemRefType(targetMemRefType, getTypeConverter()); + if (!srcInEmitC || !targetInEmitC) { + return rewriter.notifyMatchFailure(castOp.getLoc(), + "cannot convert memref type"); + } + Location loc = castOp.getLoc(); + + auto srcArrayValue = + cast>(adaptor.getSource()); + + emitc::ConstantOp zeroIndex = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); + + auto createPointerFromEmitcArray = + [loc, &rewriter, &zeroIndex]( + mlir::TypedValue arrayValue) -> emitc::ApplyOp { + int64_t rank = arrayValue.getType().getRank(); + llvm::SmallVector indices; + for (int i = 0; i < rank; ++i) { + indices.push_back(zeroIndex); + } + + emitc::SubscriptOp subPtr = rewriter.create( + loc, arrayValue, mlir::ValueRange(indices)); + emitc::ApplyOp ptr = rewriter.create( + loc, emitc::PointerType::get(arrayValue.getType().getElementType()), + rewriter.getStringAttr("&"), subPtr); + + return ptr; + }; + auto [strides, offset] = targetMemRefType.getStridesAndOffset(); + // Value offsetValue = rewriter.create( + // loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset)); + + auto srcPtr = createPointerFromEmitcArray(srcArrayValue); + // emitc::PointerType targetPointerType = + // emitc::PointerType::get(srcArrayValue.getType().getElementType()); + + auto dimensions = targetMemRefType.getShape(); + std::string reinterpretCastName = llvm::formatv( + "reinterpret_cast<{0}(*)", srcArrayValue.getType().getElementType()); + std::string dimensionsStr; + for (auto dim : dimensions) { + dimensionsStr += llvm::formatv("[{0}]", dim); + } + reinterpretCastName += llvm::formatv("{0}>", dimensionsStr); + reinterpretCastName += ">"; + + reinterpretCastName += llvm::formatv("{0}", srcPtr->getResult(0)); + + std::string outputStr = llvm::formatv( + "{0}(*){1}", srcArrayValue.getType().getElementType(), dimensionsStr); + auto outputType = emitc::PointerType::get( + emitc::OpaqueType::get(rewriter.getContext(), outputStr)); + + emitc::ConstantOp reinterpretOp = rewriter.create( + loc, outputType, + emitc::OpaqueAttr::get(rewriter.getContext(), reinterpretCastName)); + + rewriter.replaceOp(castOp, reinterpretOp.getResult()); + return success(); + } +}; + struct ConvertStore final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -321,5 +402,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { patterns.add(converter, patterns.getContext()); + ConvertLoad, ConvertReinterpretCastOp, ConvertStore>( + converter, patterns.getContext()); } From d09ccc0fbcf77ac09bafbc9b96b4bc6332c7f02a Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Thu, 7 Aug 2025 23:06:46 +0000 Subject: [PATCH 2/3] almost functional option --- .../MemRefToEmitC/MemRefToEmitC.cpp | 41 ++++++++----------- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 13 +++++- .../memref-to-emitc-reinterpret-cast.mlir | 16 ++++++++ 3 files changed, 43 insertions(+), 27 deletions(-) create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 60cccb0ece8f2..97b926861482f 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/FormatVariadic.h" #include @@ -316,36 +317,26 @@ struct ConvertReinterpretCastOp final return ptr; }; - auto [strides, offset] = targetMemRefType.getStridesAndOffset(); - // Value offsetValue = rewriter.create( - // loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset)); auto srcPtr = createPointerFromEmitcArray(srcArrayValue); - // emitc::PointerType targetPointerType = - // emitc::PointerType::get(srcArrayValue.getType().getElementType()); - - auto dimensions = targetMemRefType.getShape(); - std::string reinterpretCastName = llvm::formatv( - "reinterpret_cast<{0}(*)", srcArrayValue.getType().getElementType()); - std::string dimensionsStr; - for (auto dim : dimensions) { - dimensionsStr += llvm::formatv("[{0}]", dim); - } - reinterpretCastName += llvm::formatv("{0}>", dimensionsStr); - reinterpretCastName += ">"; - - reinterpretCastName += llvm::formatv("{0}", srcPtr->getResult(0)); + // 1. Create a TypeAttr for the target type. + TypeAttr targetTypeAttr = + TypeAttr::get(emitc::PointerType::get(targetInEmitC)); + IntegerAttr resty = rewriter.getIndexAttr(0); - std::string outputStr = llvm::formatv( - "{0}(*){1}", srcArrayValue.getType().getElementType(), dimensionsStr); - auto outputType = emitc::PointerType::get( - emitc::OpaqueType::get(rewriter.getContext(), outputStr)); + // 2. Create an ArrayAttr with the TypeAttr. This will be the + // templateArgsAttr. + ArrayAttr templateArgsAttr = rewriter.getArrayAttr({targetTypeAttr}); - emitc::ConstantOp reinterpretOp = rewriter.create( - loc, outputType, - emitc::OpaqueAttr::get(rewriter.getContext(), reinterpretCastName)); + auto reinterpretCastCall = rewriter.create( + loc, + /*result types=*/TypeRange{emitc::PointerType::get(targetInEmitC)}, + /*callee=*/"reinterpret_cast", + /*args*/ rewriter.getArrayAttr({resty}), + /*template_args=*/templateArgsAttr, + /*operands=*/ValueRange{srcPtr.getResult()}); - rewriter.replaceOp(castOp, reinterpretOp.getResult()); + rewriter.replaceOp(castOp, reinterpretCastCall.getResults()); return success(); } }; diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 8e83e455d1a7f..3d47da8c439de 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1848,8 +1848,17 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) { if (auto lType = dyn_cast(type)) return emitType(loc, lType.getValueType()); if (auto pType = dyn_cast(type)) { - if (isa(pType.getPointee())) - return emitError(loc, "cannot emit pointer to array type ") << type; + // Check if the pointee is an array type. + if (auto aType = dyn_cast(pType.getPointee())) { + // Handle pointer to array: `element_type (*)[dim]`. + if (failed(emitType(loc, aType.getElementType()))) + return failure(); + os << "(*)"; + for (auto dim : aType.getShape()) + os << "[" << dim << "]"; + return success(); + } + // Handle standard pointer: `element_type*`. if (failed(emitType(loc, pType.getPointee()))) return failure(); os << "*"; diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir new file mode 100644 index 0000000000000..5d610fe1ae642 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s + +func.func @casting(%arg0: memref<999xi32>) { + %reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1] : memref<999xi32> to memref<1x1x999xi32> + return +} + +//CHECK: module { +//CHECK-NEXT: func.func @casting(%arg0: memref<999xi32>) { +//CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32> +//CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index +//CHECK-NEXT: %2 = emitc.subscript %0[%1] : (!emitc.array<999xi32>, index) -> !emitc.lvalue +//CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue) -> !emitc.ptr +//CHECK-NEXT: %4 = emitc.call_opaque "reinterpret_cast"(%3) {args = [0 : index], template_args = [!emitc.ptr>]} : (!emitc.ptr) -> !emitc.ptr> +//CHECK-NEXT: return + From 19e23d560b00008e664c2e85df9840ae71f51ee1 Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Fri, 8 Aug 2025 21:08:08 +0000 Subject: [PATCH 3/3] cpp output that compiles --- .../MemRefToEmitC/MemRefToEmitC.cpp | 19 +++---------------- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 97b926861482f..18bd79de2ec2a 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -319,24 +319,11 @@ struct ConvertReinterpretCastOp final }; auto srcPtr = createPointerFromEmitcArray(srcArrayValue); - // 1. Create a TypeAttr for the target type. - TypeAttr targetTypeAttr = - TypeAttr::get(emitc::PointerType::get(targetInEmitC)); - IntegerAttr resty = rewriter.getIndexAttr(0); - // 2. Create an ArrayAttr with the TypeAttr. This will be the - // templateArgsAttr. - ArrayAttr templateArgsAttr = rewriter.getArrayAttr({targetTypeAttr}); + auto castCall = rewriter.create( + loc, emitc::PointerType::get(targetInEmitC), srcPtr.getResult()); - auto reinterpretCastCall = rewriter.create( - loc, - /*result types=*/TypeRange{emitc::PointerType::get(targetInEmitC)}, - /*callee=*/"reinterpret_cast", - /*args*/ rewriter.getArrayAttr({resty}), - /*template_args=*/templateArgsAttr, - /*operands=*/ValueRange{srcPtr.getResult()}); - - rewriter.replaceOp(castOp, reinterpretCastCall.getResults()); + rewriter.replaceOp(castOp, castCall); return success(); } }; diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 3d47da8c439de..6ec1c070fde83 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1756,6 +1756,20 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type, StringRef name) { + if (auto pType = dyn_cast(type)) { + if (auto aType = dyn_cast(pType.getPointee())) { + if (failed(emitType(loc, aType.getElementType()))) + return failure(); + os << " (*" << name << ")"; + for (auto dim : aType.getShape()) + os << "[" << dim << "]"; + return success(); + } + if (failed(emitType(loc, pType.getPointee()))) + return failure(); + os << " *" << name; + return success(); + } if (auto arrType = dyn_cast(type)) { if (failed(emitType(loc, arrType.getElementType()))) return failure();