Skip to content

Commit d09ccc0

Browse files
committed
almost functional option
1 parent c6ea995 commit d09ccc0

File tree

3 files changed

+43
-27
lines changed

3 files changed

+43
-27
lines changed

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/PatternMatch.h"
2121
#include "mlir/IR/TypeRange.h"
2222
#include "mlir/IR/Value.h"
23+
#include "mlir/IR/ValueRange.h"
2324
#include "mlir/Transforms/DialectConversion.h"
2425
#include "llvm/Support/FormatVariadic.h"
2526
#include <cstdint>
@@ -316,36 +317,26 @@ struct ConvertReinterpretCastOp final
316317

317318
return ptr;
318319
};
319-
auto [strides, offset] = targetMemRefType.getStridesAndOffset();
320-
// Value offsetValue = rewriter.create<emitc::ConstantOp>(
321-
// loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
322320

323321
auto srcPtr = createPointerFromEmitcArray(srcArrayValue);
324-
// emitc::PointerType targetPointerType =
325-
// emitc::PointerType::get(srcArrayValue.getType().getElementType());
326-
327-
auto dimensions = targetMemRefType.getShape();
328-
std::string reinterpretCastName = llvm::formatv(
329-
"reinterpret_cast<{0}(*)", srcArrayValue.getType().getElementType());
330-
std::string dimensionsStr;
331-
for (auto dim : dimensions) {
332-
dimensionsStr += llvm::formatv("[{0}]", dim);
333-
}
334-
reinterpretCastName += llvm::formatv("{0}>", dimensionsStr);
335-
reinterpretCastName += ">";
336-
337-
reinterpretCastName += llvm::formatv("{0}", srcPtr->getResult(0));
322+
// 1. Create a TypeAttr for the target type.
323+
TypeAttr targetTypeAttr =
324+
TypeAttr::get(emitc::PointerType::get(targetInEmitC));
325+
IntegerAttr resty = rewriter.getIndexAttr(0);
338326

339-
std::string outputStr = llvm::formatv(
340-
"{0}(*){1}", srcArrayValue.getType().getElementType(), dimensionsStr);
341-
auto outputType = emitc::PointerType::get(
342-
emitc::OpaqueType::get(rewriter.getContext(), outputStr));
327+
// 2. Create an ArrayAttr with the TypeAttr. This will be the
328+
// templateArgsAttr.
329+
ArrayAttr templateArgsAttr = rewriter.getArrayAttr({targetTypeAttr});
343330

344-
emitc::ConstantOp reinterpretOp = rewriter.create<emitc::ConstantOp>(
345-
loc, outputType,
346-
emitc::OpaqueAttr::get(rewriter.getContext(), reinterpretCastName));
331+
auto reinterpretCastCall = rewriter.create<emitc::CallOpaqueOp>(
332+
loc,
333+
/*result types=*/TypeRange{emitc::PointerType::get(targetInEmitC)},
334+
/*callee=*/"reinterpret_cast",
335+
/*args*/ rewriter.getArrayAttr({resty}),
336+
/*template_args=*/templateArgsAttr,
337+
/*operands=*/ValueRange{srcPtr.getResult()});
347338

348-
rewriter.replaceOp(castOp, reinterpretOp.getResult());
339+
rewriter.replaceOp(castOp, reinterpretCastCall.getResults());
349340
return success();
350341
}
351342
};

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1848,8 +1848,17 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
18481848
if (auto lType = dyn_cast<emitc::LValueType>(type))
18491849
return emitType(loc, lType.getValueType());
18501850
if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1851-
if (isa<ArrayType>(pType.getPointee()))
1852-
return emitError(loc, "cannot emit pointer to array type ") << type;
1851+
// Check if the pointee is an array type.
1852+
if (auto aType = dyn_cast<emitc::ArrayType>(pType.getPointee())) {
1853+
// Handle pointer to array: `element_type (*)[dim]`.
1854+
if (failed(emitType(loc, aType.getElementType())))
1855+
return failure();
1856+
os << "(*)";
1857+
for (auto dim : aType.getShape())
1858+
os << "[" << dim << "]";
1859+
return success();
1860+
}
1861+
// Handle standard pointer: `element_type*`.
18531862
if (failed(emitType(loc, pType.getPointee())))
18541863
return failure();
18551864
os << "*";
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
2+
3+
func.func @casting(%arg0: memref<999xi32>) {
4+
%reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1] : memref<999xi32> to memref<1x1x999xi32>
5+
return
6+
}
7+
8+
//CHECK: module {
9+
//CHECK-NEXT: func.func @casting(%arg0: memref<999xi32>) {
10+
//CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
11+
//CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index
12+
//CHECK-NEXT: %2 = emitc.subscript %0[%1] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
13+
//CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
14+
//CHECK-NEXT: %4 = emitc.call_opaque "reinterpret_cast"(%3) {args = [0 : index], template_args = [!emitc.ptr<!emitc.array<1x1x999xi32>>]} : (!emitc.ptr<i32>) -> !emitc.ptr<!emitc.array<1x1x999xi32>>
15+
//CHECK-NEXT: return
16+

0 commit comments

Comments
 (0)