-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][EmitC] Expand the MemRefToEmitC pass - Lowering reinterpret_cast
#152610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -20,8 +20,11 @@ | |||||||||||
#include "mlir/IR/PatternMatch.h" | ||||||||||||
#include "mlir/IR/TypeRange.h" | ||||||||||||
#include "mlir/IR/Value.h" | ||||||||||||
#include "mlir/IR/ValueRange.h" | ||||||||||||
#include "mlir/Transforms/DialectConversion.h" | ||||||||||||
#include "llvm/Support/FormatVariadic.h" | ||||||||||||
#include <cstdint> | ||||||||||||
#include <string> | ||||||||||||
|
||||||||||||
using namespace mlir; | ||||||||||||
|
||||||||||||
|
@@ -269,6 +272,62 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> { | |||||||||||
} | ||||||||||||
}; | ||||||||||||
|
||||||||||||
struct ConvertReinterpretCastOp final | ||||||||||||
: public OpConversionPattern<memref::ReinterpretCastOp> { | ||||||||||||
using OpConversionPattern::OpConversionPattern; | ||||||||||||
|
||||||||||||
LogicalResult | ||||||||||||
matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, | ||||||||||||
ConversionPatternRewriter &rewriter) const override { | ||||||||||||
|
||||||||||||
MemRefType srcType = cast<MemRefType>(castOp.getSource().getType()); | ||||||||||||
|
||||||||||||
MemRefType targetMemRefType = | ||||||||||||
cast<MemRefType>(castOp.getResult().getType()); | ||||||||||||
|
||||||||||||
auto srcInEmitC = convertMemRefType(srcType, getTypeConverter()); | ||||||||||||
auto targetInEmitC = | ||||||||||||
convertMemRefType(targetMemRefType, getTypeConverter()); | ||||||||||||
if (!srcInEmitC || !targetInEmitC) { | ||||||||||||
return rewriter.notifyMatchFailure(castOp.getLoc(), | ||||||||||||
"cannot convert memref type"); | ||||||||||||
} | ||||||||||||
Location loc = castOp.getLoc(); | ||||||||||||
|
||||||||||||
auto srcArrayValue = | ||||||||||||
cast<TypedValue<emitc::ArrayType>>(adaptor.getSource()); | ||||||||||||
|
||||||||||||
emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>( | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using |
||||||||||||
loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); | ||||||||||||
Comment on lines
+300
to
+301
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I've seen a zeroIndex crated in several places now. Its not much code, but that may be a good candidate for a helper function (e.g. in the anonymous namespace and marked static). |
||||||||||||
|
||||||||||||
auto createPointerFromEmitcArray = | ||||||||||||
[loc, &rewriter, &zeroIndex]( | ||||||||||||
mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp { | ||||||||||||
int64_t rank = arrayValue.getType().getRank(); | ||||||||||||
llvm::SmallVector<mlir::Value> indices; | ||||||||||||
for (int i = 0; i < rank; ++i) { | ||||||||||||
indices.push_back(zeroIndex); | ||||||||||||
} | ||||||||||||
Comment on lines
+307
to
+310
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Comment on lines
+303
to
+310
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This closure seems familar from your other patches. I'm guessing that means it should probably be a helper function instead, since I don't see anything in the capture list that couldn't be a parameter. |
||||||||||||
|
||||||||||||
emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>( | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you switch to the new builder form ( |
||||||||||||
loc, arrayValue, mlir::ValueRange(indices)); | ||||||||||||
emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>( | ||||||||||||
loc, emitc::PointerType::get(arrayValue.getType().getElementType()), | ||||||||||||
rewriter.getStringAttr("&"), subPtr); | ||||||||||||
|
||||||||||||
return ptr; | ||||||||||||
}; | ||||||||||||
|
||||||||||||
auto srcPtr = createPointerFromEmitcArray(srcArrayValue); | ||||||||||||
|
||||||||||||
auto castCall = rewriter.create<emitc::CastOp>( | ||||||||||||
loc, emitc::PointerType::get(targetInEmitC), srcPtr.getResult()); | ||||||||||||
|
||||||||||||
rewriter.replaceOp(castOp, castCall); | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can combine this with the preceding by doing rewriter.replaceWithNewOp |
||||||||||||
return success(); | ||||||||||||
} | ||||||||||||
}; | ||||||||||||
|
||||||||||||
struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { | ||||||||||||
using OpConversionPattern::OpConversionPattern; | ||||||||||||
|
||||||||||||
|
@@ -321,5 +380,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { | |||||||||||
void mlir::populateMemRefToEmitCConversionPatterns( | ||||||||||||
RewritePatternSet &patterns, const TypeConverter &converter) { | ||||||||||||
patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, | ||||||||||||
ConvertLoad, ConvertStore>(converter, patterns.getContext()); | ||||||||||||
ConvertLoad, ConvertReinterpretCastOp, ConvertStore>( | ||||||||||||
converter, patterns.getContext()); | ||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1756,6 +1756,20 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { | |
|
||
LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type, | ||
StringRef name) { | ||
if (auto pType = dyn_cast<emitc::PointerType>(type)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could this be a standalone PR? (this feels like it is a general refinement). |
||
if (auto aType = dyn_cast<emitc::ArrayType>(pType.getPointee())) { | ||
if (failed(emitType(loc, aType.getElementType()))) | ||
return failure(); | ||
os << " (*" << name << ")"; | ||
for (auto dim : aType.getShape()) | ||
os << "[" << dim << "]"; | ||
return success(); | ||
} | ||
if (failed(emitType(loc, pType.getPointee()))) | ||
return failure(); | ||
os << " *" << name; | ||
return success(); | ||
} | ||
if (auto arrType = dyn_cast<emitc::ArrayType>(type)) { | ||
if (failed(emitType(loc, arrType.getElementType()))) | ||
return failure(); | ||
|
@@ -1848,8 +1862,17 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) { | |
if (auto lType = dyn_cast<emitc::LValueType>(type)) | ||
return emitType(loc, lType.getValueType()); | ||
if (auto pType = dyn_cast<emitc::PointerType>(type)) { | ||
if (isa<ArrayType>(pType.getPointee())) | ||
return emitError(loc, "cannot emit pointer to array type ") << type; | ||
// Check if the pointee is an array type. | ||
if (auto aType = dyn_cast<emitc::ArrayType>(pType.getPointee())) { | ||
// Handle pointer to array: `element_type (*)[dim]`. | ||
if (failed(emitType(loc, aType.getElementType()))) | ||
return failure(); | ||
os << "(*)"; | ||
for (auto dim : aType.getShape()) | ||
os << "[" << dim << "]"; | ||
return success(); | ||
} | ||
// Handle standard pointer: `element_type*`. | ||
if (failed(emitType(loc, pType.getPointee()))) | ||
return failure(); | ||
os << "*"; | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,16 @@ | ||||||
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
The |
||||||
|
||||||
func.func @casting(%arg0: memref<999xi32>) { | ||||||
%reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1] : memref<999xi32> to memref<1x1x999xi32> | ||||||
return | ||||||
} | ||||||
|
||||||
//CHECK: module { | ||||||
//CHECK-NEXT: func.func @casting(%arg0: memref<999xi32>) { | ||||||
//CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32> | ||||||
//CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index | ||||||
//CHECK-NEXT: %2 = emitc.subscript %0[%1] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> | ||||||
//CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> | ||||||
//CHECK-NEXT: %4 = emitc.call_opaque "reinterpret_cast"(%3) {args = [0 : index], template_args = [!emitc.ptr<!emitc.array<1x1x999xi32>>]} : (!emitc.ptr<i32>) -> !emitc.ptr<!emitc.array<1x1x999xi32>> | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. an actual There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1, good to check. |
||||||
//CHECK-NEXT: return | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Empty lines are for me similar to paragraphs in text, it creates logical separations/groupings which aid reading. Here I can't quite figure those out.