Skip to content

Commit c6ea995

Browse files
committed
needs improvment
1 parent df392b5 commit c6ea995

File tree

1 file changed

+83
-1
lines changed

1 file changed

+83
-1
lines changed

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
#include "mlir/IR/TypeRange.h"
2222
#include "mlir/IR/Value.h"
2323
#include "mlir/Transforms/DialectConversion.h"
24+
#include "llvm/Support/FormatVariadic.h"
2425
#include <cstdint>
26+
#include <string>
2527

2628
using namespace mlir;
2729

@@ -269,6 +271,85 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
269271
}
270272
};
271273

274+
struct ConvertReinterpretCastOp final
275+
: public OpConversionPattern<memref::ReinterpretCastOp> {
276+
using OpConversionPattern::OpConversionPattern;
277+
278+
LogicalResult
279+
matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
280+
ConversionPatternRewriter &rewriter) const override {
281+
282+
MemRefType srcType = cast<MemRefType>(castOp.getSource().getType());
283+
284+
MemRefType targetMemRefType =
285+
cast<MemRefType>(castOp.getResult().getType());
286+
287+
auto srcInEmitC = convertMemRefType(srcType, getTypeConverter());
288+
auto targetInEmitC =
289+
convertMemRefType(targetMemRefType, getTypeConverter());
290+
if (!srcInEmitC || !targetInEmitC) {
291+
return rewriter.notifyMatchFailure(castOp.getLoc(),
292+
"cannot convert memref type");
293+
}
294+
Location loc = castOp.getLoc();
295+
296+
auto srcArrayValue =
297+
cast<TypedValue<emitc::ArrayType>>(adaptor.getSource());
298+
299+
emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
300+
loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
301+
302+
auto createPointerFromEmitcArray =
303+
[loc, &rewriter, &zeroIndex](
304+
mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp {
305+
int64_t rank = arrayValue.getType().getRank();
306+
llvm::SmallVector<mlir::Value> indices;
307+
for (int i = 0; i < rank; ++i) {
308+
indices.push_back(zeroIndex);
309+
}
310+
311+
emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
312+
loc, arrayValue, mlir::ValueRange(indices));
313+
emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
314+
loc, emitc::PointerType::get(arrayValue.getType().getElementType()),
315+
rewriter.getStringAttr("&"), subPtr);
316+
317+
return ptr;
318+
};
319+
auto [strides, offset] = targetMemRefType.getStridesAndOffset();
320+
// Value offsetValue = rewriter.create<emitc::ConstantOp>(
321+
// loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
322+
323+
auto srcPtr = createPointerFromEmitcArray(srcArrayValue);
324+
// emitc::PointerType targetPointerType =
325+
// emitc::PointerType::get(srcArrayValue.getType().getElementType());
326+
327+
auto dimensions = targetMemRefType.getShape();
328+
std::string reinterpretCastName = llvm::formatv(
329+
"reinterpret_cast<{0}(*)", srcArrayValue.getType().getElementType());
330+
std::string dimensionsStr;
331+
for (auto dim : dimensions) {
332+
dimensionsStr += llvm::formatv("[{0}]", dim);
333+
}
334+
reinterpretCastName += llvm::formatv("{0}>", dimensionsStr);
335+
reinterpretCastName += ">";
336+
337+
reinterpretCastName += llvm::formatv("{0}", srcPtr->getResult(0));
338+
339+
std::string outputStr = llvm::formatv(
340+
"{0}(*){1}", srcArrayValue.getType().getElementType(), dimensionsStr);
341+
auto outputType = emitc::PointerType::get(
342+
emitc::OpaqueType::get(rewriter.getContext(), outputStr));
343+
344+
emitc::ConstantOp reinterpretOp = rewriter.create<emitc::ConstantOp>(
345+
loc, outputType,
346+
emitc::OpaqueAttr::get(rewriter.getContext(), reinterpretCastName));
347+
348+
rewriter.replaceOp(castOp, reinterpretOp.getResult());
349+
return success();
350+
}
351+
};
352+
272353
struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
273354
using OpConversionPattern::OpConversionPattern;
274355

@@ -321,5 +402,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
321402
void mlir::populateMemRefToEmitCConversionPatterns(
322403
RewritePatternSet &patterns, const TypeConverter &converter) {
323404
patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
324-
ConvertLoad, ConvertStore>(converter, patterns.getContext());
405+
ConvertLoad, ConvertReinterpretCastOp, ConvertStore>(
406+
converter, patterns.getContext());
325407
}

0 commit comments

Comments
 (0)