-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][draft] Support 1:N dialect conversion #112141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
You can test this locally with the following command:git-clang-format --diff de0fd64bedd23660f557833cc0108c3fb2be3918 ff124133d889c014f00a7739c267d56bf9312e18 --extensions cpp,h -- mlir/include/mlir/Conversion/LLVMCommon/Pattern.h mlir/include/mlir/Transforms/DialectConversion.h mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp mlir/lib/Transforms/Utils/DialectConversion.cpp mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp mlir/test/lib/Dialect/Test/TestPatterns.cppView the diff from clang-format here.diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 20a2a10ded..e310b926b9 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,70 +153,74 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});
-/*
- // Argument materializations convert from the new block argument types
- // (multiple SSA values that make up a memref descriptor) back to the
- // original block argument type. The dialect conversion framework will then
- // insert a target materialization from the original block argument type to
- // a legal type.
- addArgumentMaterialization([&](OpBuilder &builder,
- UnrankedMemRefType resultType,
- ValueRange inputs, Location loc) {
- if (inputs.size() == 1) {
- // Bare pointers are not supported for unranked memrefs because a
- // memref descriptor cannot be built just from a bare pointer.
- return Value();
- }
- Value desc =
- UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
- // An argument materialization must return a value of type
- // `resultType`, so insert a cast from the memref descriptor type
- // (!llvm.struct) to the original memref type.
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
- .getResult(0);
- });
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs, Location loc) {
- Value desc;
- if (inputs.size() == 1) {
- // This is a bare pointer. We allow bare pointers only for function entry
- // blocks.
- BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
- if (!barePtr)
- return Value();
- Block *block = barePtr.getOwner();
- if (!block->isEntryBlock() ||
- !isa<FunctionOpInterface>(block->getParentOp()))
+ /*
+ // Argument materializations convert from the new block argument types
+ // (multiple SSA values that make up a memref descriptor) back to the
+ // original block argument type. The dialect conversion framework will then
+ // insert a target materialization from the original block argument type to
+ // a legal type.
+ addArgumentMaterialization([&](OpBuilder &builder,
+ UnrankedMemRefType resultType,
+ ValueRange inputs, Location loc) {
+ if (inputs.size() == 1) {
+ // Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
return Value();
- desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
- inputs[0]);
- } else {
- desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
- }
- // An argument materialization must return a value of type `resultType`,
- // so insert a cast from the memref descriptor type (!llvm.struct) to the
- // original memref type.
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
- .getResult(0);
- });
-
-*/
+ }
+ Value desc =
+ UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
+ inputs);
+ // An argument materialization must return a value of type
+ // `resultType`, so insert a cast from the memref descriptor type
+ // (!llvm.struct) to the original memref type.
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+ .getResult(0);
+ });
+ addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
+ ValueRange inputs, Location loc) {
+ Value desc;
+ if (inputs.size() == 1) {
+ // This is a bare pointer. We allow bare pointers only for function
+ entry
+ // blocks.
+ BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
+ if (!barePtr)
+ return Value();
+ Block *block = barePtr.getOwner();
+ if (!block->isEntryBlock() ||
+ !isa<FunctionOpInterface>(block->getParentOp()))
+ return Value();
+ desc = MemRefDescriptor::fromStaticShape(builder, loc, *this,
+ resultType, inputs[0]); } else { desc = MemRefDescriptor::pack(builder, loc,
+ *this, resultType, inputs);
+ }
+ // An argument materialization must return a value of type `resultType`,
+ // so insert a cast from the memref descriptor type (!llvm.struct) to the
+ // original memref type.
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+ .getResult(0);
+ });
+
+ */
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
- //if (inputs.size() != 1)
- // return Value();
+ // if (inputs.size() != 1)
+ // return Value();
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc) {
- if (inputs.size()== 1 && isa<LLVM::LLVMStructType>(inputs.front().getType())) return Value();
+ if (inputs.size() == 1 &&
+ isa<LLVM::LLVMStructType>(inputs.front().getType()))
+ return Value();
Value desc;
- if (inputs.size() == 1 && isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
+ if (inputs.size() == 1 &&
+ isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
@@ -229,10 +233,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
} else {
- //llvm::errs() << "pack elems: " << inputs.size() << "\n";
- //llvm::errs() << inputs[0] << "\n";
+ // llvm::errs() << "pack elems: " << inputs.size() << "\n";
+ // llvm::errs() << inputs[0] << "\n";
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
- //llvm::errs() << "done packing\n";
+ // llvm::errs() << "done packing\n";
}
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
@@ -240,8 +244,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
- addSourceMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType,
- ValueRange inputs, Location loc) {
+ addSourceMaterialization([&](OpBuilder &builder,
+ UnrankedMemRefType resultType, ValueRange inputs,
+ Location loc) {
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
@@ -264,8 +269,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
.getResult(0);
});
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc, Type originalType) -> Value {
+ ValueRange inputs, Location loc,
+ Type originalType) -> Value {
llvm::errs() << "TARGET MAT: -> " << resultType << "\n";
if (!originalType) {
llvm::errs() << " -- no orig\n";
@@ -275,8 +280,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
if (inputs.size() == 1) {
Value input = inputs.front();
- if (auto castOp =input.getDefiningOp<UnrealizedConversionCastOp>()) {
- if (castOp.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
+ if (auto castOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
+ if (castOp.getInputs().size() == 1 &&
+ isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
input = castOp.getInputs()[0];
}
}
@@ -290,23 +296,23 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
!isa<FunctionOpInterface>(block->getParentOp()))
return Value();
// Bare ptr
- return MemRefDescriptor::fromStaticShape(builder, loc, *this, memrefType,
- input);
+ return MemRefDescriptor::fromStaticShape(builder, loc, *this,
+ memrefType, input);
}
return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
}
if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
if (inputs.size() == 1) {
- // Bare pointers are not supported for unranked memrefs because a
- // memref descriptor cannot be built just from a bare pointer.
- return Value();
+ // Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
+ return Value();
}
- return UnrankedMemRefDescriptor::pack(builder, loc, *this,
- memrefType, inputs);
+ return UnrankedMemRefDescriptor::pack(builder, loc, *this, memrefType,
+ inputs);
}
- return Value();
+ return Value();
});
// Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 9b9682148b..a40678f912 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -122,15 +122,16 @@ struct ConversionValueMapping {
if (it == mapping.end())
return Value();
const SmallVector<Value, 1> &repl = it->second;
- if (repl.size() != 1) return Value();
- return repl.front();
-/*
- if (!mapping.contains(from)) return Value();
- auto it = llvm::find(mapping, from);
- const SmallVector<Value, 1> &repl = it->second;
- if (repl.size() != 1) return Value();
+ if (repl.size() != 1)
+ return Value();
return repl.front();
- */
+ /*
+ if (!mapping.contains(from)) return Value();
+ auto it = llvm::find(mapping, from);
+ const SmallVector<Value, 1> &repl = it->second;
+ if (repl.size() != 1) return Value();
+ return repl.front();
+ */
}
/// Find the most recently mapped values for the given value. If the value is
@@ -1299,7 +1300,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand);
ValueRange castValues = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(vals), operandLoc,
- /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType, currentTypeConverter);
+ /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType,
+ currentTypeConverter);
mapping.mapMaterialization(vals, castValues);
remapped.push_back(castValues);
@@ -1454,7 +1456,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter) {
+ ValueRange inputs, TypeRange outputTypes, Type originalType,
+ const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (TypeRange(inputs) == outputTypes)
return inputs;
@@ -1465,7 +1468,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, originalType);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+ originalType);
return convertOp.getResults();
}
@@ -1686,7 +1690,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
SmallVector<Value, 1> mapped = impl->mapping.lookupOrDefault(from);
- assert(mapped.size() == 1 && "replaceUsesOfBlockArgument is not supported for 1:N replacements");
+ assert(mapped.size() == 1 &&
+ "replaceUsesOfBlockArgument is not supported for 1:N replacements");
impl->mapping.map(mapped.front(), to);
}
@@ -2599,7 +2604,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
rewrite->getOriginalType());
break;
case MaterializationKind::Source:
- assert(op.getNumResults() == 1 && "*:N source materializations are not supported");
+ assert(op.getNumResults() == 1 &&
+ "*:N source materializations are not supported");
Value sourceMat = converter->materializeSourceConversion(
rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
if (sourceMat)
@@ -2617,7 +2623,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
InFlightDiagnostic diag = op->emitError()
<< "failed to legalize unresolved materialization "
"from ("
- << inputOperands.getTypes() << ") to (" << op.getResultTypes()
+ << inputOperands.getTypes() << ") to ("
+ << op.getResultTypes()
<< ") that remained live after conversion";
diag.attachNote(op->getUsers().begin()->getLoc())
<< "see existing live user here: " << *op->getUsers().begin();
@@ -2741,7 +2748,8 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
for (Value originalValue : replacedValues) {
// If this value is directly replaced with a value of the same type,
// there is nothing to do.
- Value repl = rewriterImpl.mapping.lookupDirectSingleReplacement(originalValue);
+ Value repl =
+ rewriterImpl.mapping.lookupDirectSingleReplacement(originalValue);
if (repl && repl.getType() == originalValue.getType())
continue;
// If the type of this value changed and the value is still live, we need
@@ -2761,8 +2769,8 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(newValues),
originalValue.getLoc(),
- /*inputs=*/newValues, /*outputTypes=*/originalValue.getType(), /*originalType=*/Type(),
- converter)[0];
+ /*inputs=*/newValues, /*outputTypes=*/originalValue.getType(),
+ /*originalType=*/Type(), converter)[0];
rewriterImpl.mapping.mapMaterialization(newValues, {castValue});
llvm::append_range(inverseMapping[castValue], newValues);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 9154964465..9fef1e4819 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1076,7 +1076,8 @@ struct TestUpdateConsumerType : public ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- llvm::errs() << "TestUpdateConsumerType operand: " << operands.front() << "\n";
+ llvm::errs() << "TestUpdateConsumerType operand: " << operands.front()
+ << "\n";
// Verify that the incoming operand has been successfully remapped to F64.
if (!operands[0].getType().isF64())
return failure();
|
7ec251b to
820769d
Compare
0c57792 to
7a7d81a
Compare
52ebfbc to
76ccdee
Compare
Apply suggestions from code review Co-authored-by: Markus Böck <[email protected]> address comments
do not build argument materializations anymore fix more tests Fix decompose call graph test
76ccdee to
ff12413
Compare
001453c to
f5ed959
Compare
Base automatically changed from
users/matthias-springer/replace_with_multiple
to
main
November 14, 2024 01:28
Member
Author
|
This PR was split into multiple PRs. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
No description provided.