-
Couldn't load subscription status.
- Fork 15k
[mlir][bufferization] Test tensor encoding -> memref layout conversion #161166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ece4805
a6c3ef2
eb51e55
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,7 +24,10 @@ def Test_Dialect : Dialect { | |
| let useDefaultTypePrinterParser = 0; | ||
| let useDefaultAttributePrinterParser = 1; | ||
| let isExtensible = 1; | ||
| let dependentDialects = ["::mlir::DLTIDialect"]; | ||
| let dependentDialects = [ | ||
| "::mlir::DLTIDialect", | ||
| "::mlir::bufferization::BufferizationDialect" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Seems required; otherwise some |
||
| ]; | ||
| let discardableAttrs = (ins | ||
| "mlir::IntegerAttr":$discardable_attr_key, | ||
| "SimpleAAttr":$other_discardable_attr_key | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1425,6 +1425,39 @@ TestMultiSlotAlloca::handleDestructuringComplete( | |
| return createNewMultiAllocaWithoutSlot(slot, builder, *this); | ||
| } | ||
|
|
||
| namespace { | ||
| /// Returns test dialect's memref layout for test dialect's tensor encoding when | ||
| /// applicable. | ||
| MemRefLayoutAttrInterface | ||
| getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) { | ||
| if (auto encoding = | ||
| dyn_cast<test::TestTensorEncodingAttr>(tensorType.getEncoding())) { | ||
| return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get( | ||
| tensorType.getContext(), encoding.getDummy())); | ||
| } | ||
| return {}; | ||
| } | ||
|
|
||
| /// Auxiliary bufferization function for test and builtin tensors. | ||
| bufferization::BufferLikeType | ||
| convertTensorToBuffer(mlir::Operation *op, | ||
| const bufferization::BufferizationOptions &options, | ||
| bufferization::TensorLikeType tensorLike) { | ||
| auto buffer = | ||
| *tensorLike.getBufferType(options, [&]() { return op->emitError(); }); | ||
| if (auto memref = dyn_cast<MemRefType>(buffer)) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. note: if we had an option callback that provides customizable layout inference, this branch could be avoided. Instead, the one-shot bufferization options could be configured and this whole thing becomes just |
||
| // Note: For the sake of testing, we want to ensure that encoding -> layout | ||
| // bufferization happens. This is currently achieved manually. | ||
| auto layout = | ||
| getMemRefLayoutForTensorEncoding(cast<RankedTensorType>(tensorLike)); | ||
| return cast<bufferization::BufferLikeType>( | ||
| MemRefType::get(memref.getShape(), memref.getElementType(), layout, | ||
| memref.getMemorySpace())); | ||
| } | ||
| return buffer; | ||
| } | ||
| } // namespace | ||
|
|
||
| ::mlir::LogicalResult test::TestDummyTensorOp::bufferize( | ||
| ::mlir::RewriterBase &rewriter, | ||
| const ::mlir::bufferization::BufferizationOptions &options, | ||
|
|
@@ -1435,8 +1468,8 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize( | |
| return failure(); | ||
|
|
||
| const auto outType = getOutput().getType(); | ||
| const auto bufferizedOutType = test::TestMemrefType::get( | ||
| getContext(), outType.getShape(), outType.getElementType(), nullptr); | ||
| const auto bufferizedOutType = | ||
| convertTensorToBuffer(getOperation(), options, outType); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's better to call There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do you mean Op::getBufferType? Note that this is a different op from the one below (dummy vs. create), where I don't override ::getBufferType(). The "create tensor" one already does what you suggest here. |
||
| // replace op with memref analogy | ||
| auto dummyMemrefOp = test::TestDummyMemrefOp::create( | ||
| rewriter, getLoc(), bufferizedOutType, *buffer); | ||
|
|
@@ -1470,13 +1503,12 @@ ::mlir::LogicalResult test::TestCreateTensorOp::bufferize( | |
|
|
||
| mlir::FailureOr<mlir::bufferization::BufferLikeType> | ||
| test::TestCreateTensorOp::getBufferType( | ||
| mlir::Value value, const mlir::bufferization::BufferizationOptions &, | ||
| mlir::Value value, const mlir::bufferization::BufferizationOptions &options, | ||
| const mlir::bufferization::BufferizationState &, | ||
| llvm::SmallVector<::mlir::Value> &) { | ||
| const auto type = dyn_cast<test::TestTensorType>(value.getType()); | ||
| const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType()); | ||
| if (type == nullptr) | ||
| return failure(); | ||
|
|
||
| return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get( | ||
| getContext(), type.getShape(), type.getElementType(), nullptr)); | ||
| return convertTensorToBuffer(getOperation(), options, type); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note: I don't know why, but somehow running the test locally failed for me without this fix. Given that mlir-opt produces
memref<?xf32> for me, I'm not sure how this works in main right now :|