[mlir][Transforms] Handle attributes in 1-to-many function conversions #162579
base: main
Function argument attributes are associated with the argument index, which may change during a 1-to-many type conversion but is currently not updated by the conversion pattern. As a result, attributes get applied to the wrong arguments after the conversion. Happy to add a lit test, if you can suggest a good place for it.
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir
Author: None (peterbell10)
Full diff: https://github.com/llvm/llvm-project/pull/162579.diff
1 file affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3a23bbfd70eac..ec8d0971f2bbe 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3789,6 +3789,27 @@ TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
 // FunctionOpInterfaceSignatureConversion
 //===----------------------------------------------------------------------===//
 
+static SmallVector<Attribute>
+convertFuncOpAttrs(FunctionOpInterface funcOp,
+                   TypeConverter::SignatureConversion &sigConv,
+                   FunctionType newType) {
+  if (newType.getNumInputs() == funcOp.getNumArguments()) {
+    return {};
+  }
+  ArrayAttr allArgAttrs = funcOp.getAllArgAttrs();
+  if (!allArgAttrs)
+    return {};
+
+  SmallVector<Attribute> newAttrs(newType.getNumInputs());
+  for (auto i : llvm::seq(allArgAttrs.size())) {
+    auto mapping = sigConv.getInputMapping(i);
+    assert(mapping.has_value());
+    auto outIdx = mapping->inputNo;
+    newAttrs[outIdx] = allArgAttrs[i];
+  }
+  return newAttrs;
+}
+
 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
                                         const TypeConverter &typeConverter,
                                         ConversionPatternRewriter &rewriter) {
@@ -3809,7 +3830,16 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
   auto newType = FunctionType::get(rewriter.getContext(),
                                    result.getConvertedTypes(), newResults);
 
-  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
+  // If using 1-to-n type conversion, we must re-map argument attributes
+  // to the corresponding new argument index.
+  auto newArgAttrs = convertFuncOpAttrs(funcOp, result, newType);
+
+  rewriter.modifyOpInPlace(funcOp, [&] {
+    funcOp.setType(newType);
+    if (!newArgAttrs.empty()) {
+      funcOp.setAllArgAttrs(newArgAttrs);
+    }
+  });
   return success();
 }
✅ With the latest revision this PR passed the C/C++ code formatter.
I'd like to get a few more opinions on this.
So far, we have refrained from propagating discardable attributes in patterns; this PR goes in the opposite direction. At the same time, with a 1:N conversion there is an even higher chance that the attributes are semantically incorrect at their current position. Simply dropping all discardable attributes, however, may break a lot of code.
One option I could live with: add an optional lambda to the pattern that allows downstream projects to control how attributes are converted/propagated. Basically, convertFuncOpAttrs would live inside a lambda in your downstream project.
Any thoughts?
  for (auto i : llvm::seq(allArgAttrs.size())) {
    auto mapping = sigConv.getInputMapping(i);
    assert(mapping.has_value());
    auto outIdx = mapping->inputNo;
What if an argument is mapped to multiple block arguments (i.e., mapping.size > 1)?