-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[Local] Make combineAAMetadata() more principled #122091
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-transforms Author: Nikita Popov (nikic) ChangesThis moves combineAAMetadata() into Local and implements it via a new AAOnly flag, which will intersect only AA metadata and keep other known metadata. The existing KnownIDs list is dropped, because it is redundant with the switch in combineMetadata(), which already drops unknown metadata. I tried a few variants of this, and ultimately went with the AAOnly flag because this way we make an explicit choice for each metadata kind supported by combineMetadata(), and ignoring the flag gives you conservatively correct behavior. I checked that the memcpy tests still pass if we adjust the logic for MD_memprof/MD_callsite to drop the metadata instead of arbitrarily picking one. Fixes #121495. Full diff: https://github.com/llvm/llvm-project/pull/122091.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h
index 40c448593807bb..db064e1f41f023 100644
--- a/llvm/include/llvm/Transforms/Utils/Local.h
+++ b/llvm/include/llvm/Transforms/Utils/Local.h
@@ -412,19 +412,6 @@ Instruction *removeUnwindEdge(BasicBlock *BB, DomTreeUpdater *DTU = nullptr);
bool removeUnreachableBlocks(Function &F, DomTreeUpdater *DTU = nullptr,
MemorySSAUpdater *MSSAU = nullptr);
-/// DO NOT CALL EXTERNALLY.
-/// FIXME: https://github.com/llvm/llvm-project/issues/121495
-/// Once external callers of this function are removed, either inline into
-/// combineMetadataForCSE, or internalize and remove KnownIDs parameter.
-///
-/// Combine the metadata of two instructions so that K can replace J. Some
-/// metadata kinds can only be kept if K does not move, meaning it dominated
-/// J in the original IR.
-///
-/// Metadata not listed as known via KnownIDs is removed
-void combineMetadata(Instruction *K, const Instruction *J,
- ArrayRef<unsigned> KnownIDs, bool DoesKMove);
-
/// Combine the metadata of two instructions so that K can replace J. This
/// specifically handles the case of CSE-like transformations. Some
/// metadata can only be kept if K dominates J. For this to be correct,
@@ -434,6 +421,11 @@ void combineMetadata(Instruction *K, const Instruction *J,
void combineMetadataForCSE(Instruction *K, const Instruction *J,
bool DoesKMove);
+/// Combine metadata of two instructions, where instruction J is a memory
+/// access that has been merged into K. This will intersect alias-analysis
+/// metadata, while preserving other known metadata.
+void combineAAMetadata(Instruction *K, const Instruction *J);
+
/// Copy the metadata from the source instruction to the destination (the
/// replacement for the source instruction).
void copyMetadataForLoad(LoadInst &Dest, const LoadInst &Source);
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 5f7cb92d239bc1..1de3219bc80429 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -341,21 +341,6 @@ static bool writtenBetween(MemorySSA *MSSA, BatchAAResults &AA,
return !MSSA->dominates(Clobber, Start);
}
-// Update AA metadata
-static void combineAAMetadata(Instruction *ReplInst, Instruction *I) {
- // FIXME: MD_tbaa_struct and MD_mem_parallel_loop_access should also be
- // handled here, but combineMetadata doesn't support them yet
- unsigned KnownIDs[] = {
- LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
- LLVMContext::MD_noalias, LLVMContext::MD_invariant_group,
- LLVMContext::MD_access_group, LLVMContext::MD_prof,
- LLVMContext::MD_memprof, LLVMContext::MD_callsite};
- // FIXME: https://github.com/llvm/llvm-project/issues/121495
- // Use custom AA metadata combining handling instead of combineMetadata, which
- // is meant for CSE and will drop any metadata not in the KnownIDs list.
- combineMetadata(ReplInst, I, KnownIDs, true);
-}
-
/// When scanning forward over instructions, we look for some other patterns to
/// fold away. In particular, this looks for stores to neighboring locations of
/// memory. If it sees enough consecutive ones, it attempts to merge them
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 1e4061cb0771e5..918b3a4225ed91 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -3308,13 +3308,11 @@ bool llvm::removeUnreachableBlocks(Function &F, DomTreeUpdater *DTU,
return Changed;
}
-// FIXME: https://github.com/llvm/llvm-project/issues/121495
-// Once external callers of this function are removed, either inline into
-// combineMetadataForCSE, or internalize and remove KnownIDs parameter.
-void llvm::combineMetadata(Instruction *K, const Instruction *J,
- ArrayRef<unsigned> KnownIDs, bool DoesKMove) {
+/// If AAOnly is set, only intersect alias analysis metadata and preserve other
+/// known metadata. Unknown metadata is always dropped.
+static void combineMetadata(Instruction *K, const Instruction *J,
+ bool DoesKMove, bool AAOnly = false) {
SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata;
- K->dropUnknownNonDebugMetadata(KnownIDs);
K->getAllMetadataOtherThanDebugLoc(Metadata);
for (const auto &MD : Metadata) {
unsigned Kind = MD.first;
@@ -3323,16 +3321,13 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J,
switch (Kind) {
default:
- // FIXME: https://github.com/llvm/llvm-project/issues/121495
- // Change to removing only explicitly listed other metadata, and assert
- // on unknown metadata, to avoid inadvertently dropping newly added
- // metadata types.
K->setMetadata(Kind, nullptr); // Remove unknown metadata
break;
case LLVMContext::MD_dbg:
llvm_unreachable("getAllMetadataOtherThanDebugLoc returned a MD_dbg");
case LLVMContext::MD_DIAssignID:
- K->mergeDIAssignID(J);
+ if (!AAOnly)
+ K->mergeDIAssignID(J);
break;
case LLVMContext::MD_tbaa:
if (DoesKMove)
@@ -3353,11 +3348,12 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J,
intersectAccessGroups(K, J));
break;
case LLVMContext::MD_range:
- if (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef))
+ if (!AAOnly && (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)))
K->setMetadata(Kind, MDNode::getMostGenericRange(JMD, KMD));
break;
case LLVMContext::MD_fpmath:
- K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD));
+ if (!AAOnly)
+ K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD));
break;
case LLVMContext::MD_invariant_load:
// If K moves, only set the !invariant.load if it is present in both
@@ -3366,7 +3362,7 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J,
K->setMetadata(Kind, JMD);
break;
case LLVMContext::MD_nonnull:
- if (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef))
+ if (!AAOnly && (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)))
K->setMetadata(Kind, JMD);
break;
case LLVMContext::MD_invariant_group:
@@ -3376,36 +3372,39 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J,
// Combine MMRAs
break;
case LLVMContext::MD_align:
- if (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef))
+ if (!AAOnly && (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)))
K->setMetadata(
Kind, MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
break;
case LLVMContext::MD_dereferenceable:
case LLVMContext::MD_dereferenceable_or_null:
- if (DoesKMove)
+ if (!AAOnly && DoesKMove)
K->setMetadata(Kind,
MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
break;
case LLVMContext::MD_memprof:
- K->setMetadata(Kind, MDNode::getMergedMemProfMetadata(KMD, JMD));
+ if (!AAOnly)
+ K->setMetadata(Kind, MDNode::getMergedMemProfMetadata(KMD, JMD));
break;
case LLVMContext::MD_callsite:
- K->setMetadata(Kind, MDNode::getMergedCallsiteMetadata(KMD, JMD));
+ if (!AAOnly)
+ K->setMetadata(Kind, MDNode::getMergedCallsiteMetadata(KMD, JMD));
break;
case LLVMContext::MD_preserve_access_index:
// Preserve !preserve.access.index in K.
break;
case LLVMContext::MD_noundef:
// If K does move, keep noundef if it is present in both instructions.
- if (DoesKMove)
+ if (!AAOnly && DoesKMove)
K->setMetadata(Kind, JMD);
break;
case LLVMContext::MD_nontemporal:
// Preserve !nontemporal if it is present on both instructions.
- K->setMetadata(Kind, JMD);
+ if (!AAOnly)
+ K->setMetadata(Kind, JMD);
break;
case LLVMContext::MD_prof:
- if (DoesKMove)
+ if (!AAOnly && DoesKMove)
K->setMetadata(Kind, MDNode::getMergedProfMetadata(KMD, JMD, K, J));
break;
case LLVMContext::MD_noalias_addrspace:
@@ -3437,28 +3436,12 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J,
}
void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J,
- bool KDominatesJ) {
- unsigned KnownIDs[] = {LLVMContext::MD_tbaa,
- LLVMContext::MD_alias_scope,
- LLVMContext::MD_noalias,
- LLVMContext::MD_range,
- LLVMContext::MD_fpmath,
- LLVMContext::MD_invariant_load,
- LLVMContext::MD_nonnull,
- LLVMContext::MD_invariant_group,
- LLVMContext::MD_align,
- LLVMContext::MD_dereferenceable,
- LLVMContext::MD_dereferenceable_or_null,
- LLVMContext::MD_access_group,
- LLVMContext::MD_preserve_access_index,
- LLVMContext::MD_prof,
- LLVMContext::MD_nontemporal,
- LLVMContext::MD_noundef,
- LLVMContext::MD_mmra,
- LLVMContext::MD_noalias_addrspace,
- LLVMContext::MD_memprof,
- LLVMContext::MD_callsite};
- combineMetadata(K, J, KnownIDs, KDominatesJ);
+ bool DoesKMove) {
+ combineMetadata(K, J, DoesKMove);
+}
+
+void llvm::combineAAMetadata(Instruction *K, const Instruction *J) {
+ combineMetadata(K, J, /*DoesKMove=*/true, /*AAOnly=*/true);
}
void llvm::copyMetadataForLoad(LoadInst &Dest, const LoadInst &Source) {
|
You can test this locally with the following command:git-clang-format --diff a5c3cbf7e0df23ca898e4f65e78531641fe4bf60 010d34dc567ad0a0a5a8bf95b47bd75afcfb2779 --extensions h,cpp -- llvm/include/llvm/Transforms/Utils/Local.h llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp llvm/lib/Transforms/Utils/Local.cppView the diff from clang-format here.diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 2d6f6a3b23..ba62cc3492 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -3321,98 +3321,97 @@ static void combineMetadata(Instruction *K, const Instruction *J,
// TODO: Assert that this switch is exhaustive for fixed MD kinds.
switch (Kind) {
- default:
- K->setMetadata(Kind, nullptr); // Remove unknown metadata
- break;
- case LLVMContext::MD_dbg:
- llvm_unreachable("getAllMetadataOtherThanDebugLoc returned a MD_dbg");
- case LLVMContext::MD_DIAssignID:
- if (!AAOnly)
- K->mergeDIAssignID(J);
- break;
- case LLVMContext::MD_tbaa:
- if (DoesKMove)
- K->setMetadata(Kind, MDNode::getMostGenericTBAA(JMD, KMD));
- break;
- case LLVMContext::MD_alias_scope:
- if (DoesKMove)
- K->setMetadata(Kind, MDNode::getMostGenericAliasScope(JMD, KMD));
- break;
- case LLVMContext::MD_noalias:
- case LLVMContext::MD_mem_parallel_loop_access:
- if (DoesKMove)
- K->setMetadata(Kind, MDNode::intersect(JMD, KMD));
- break;
- case LLVMContext::MD_access_group:
- if (DoesKMove)
- K->setMetadata(LLVMContext::MD_access_group,
- intersectAccessGroups(K, J));
- break;
- case LLVMContext::MD_range:
- if (!AAOnly && (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)))
- K->setMetadata(Kind, MDNode::getMostGenericRange(JMD, KMD));
- break;
- case LLVMContext::MD_fpmath:
- if (!AAOnly)
- K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD));
- break;
- case LLVMContext::MD_invariant_load:
- // If K moves, only set the !invariant.load if it is present in both
- // instructions.
- if (DoesKMove)
- K->setMetadata(Kind, JMD);
- break;
- case LLVMContext::MD_nonnull:
- if (!AAOnly && (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)))
- K->setMetadata(Kind, JMD);
- break;
- case LLVMContext::MD_invariant_group:
- // Preserve !invariant.group in K.
- break;
- case LLVMContext::MD_mmra:
- // Combine MMRAs
- break;
- case LLVMContext::MD_align:
- if (!AAOnly && (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)))
- K->setMetadata(
- Kind, MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
- break;
- case LLVMContext::MD_dereferenceable:
- case LLVMContext::MD_dereferenceable_or_null:
- if (!AAOnly && DoesKMove)
- K->setMetadata(Kind,
- MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
- break;
- case LLVMContext::MD_memprof:
- if (!AAOnly)
- K->setMetadata(Kind, MDNode::getMergedMemProfMetadata(KMD, JMD));
- break;
- case LLVMContext::MD_callsite:
- if (!AAOnly)
- K->setMetadata(Kind, MDNode::getMergedCallsiteMetadata(KMD, JMD));
- break;
- case LLVMContext::MD_preserve_access_index:
- // Preserve !preserve.access.index in K.
- break;
- case LLVMContext::MD_noundef:
- // If K does move, keep noundef if it is present in both instructions.
- if (!AAOnly && DoesKMove)
- K->setMetadata(Kind, JMD);
- break;
- case LLVMContext::MD_nontemporal:
- // Preserve !nontemporal if it is present on both instructions.
- if (!AAOnly)
- K->setMetadata(Kind, JMD);
- break;
- case LLVMContext::MD_prof:
- if (!AAOnly && DoesKMove)
- K->setMetadata(Kind, MDNode::getMergedProfMetadata(KMD, JMD, K, J));
- break;
- case LLVMContext::MD_noalias_addrspace:
- if (DoesKMove)
- K->setMetadata(Kind,
- MDNode::getMostGenericNoaliasAddrspace(JMD, KMD));
- break;
+ default:
+ K->setMetadata(Kind, nullptr); // Remove unknown metadata
+ break;
+ case LLVMContext::MD_dbg:
+ llvm_unreachable("getAllMetadataOtherThanDebugLoc returned a MD_dbg");
+ case LLVMContext::MD_DIAssignID:
+ if (!AAOnly)
+ K->mergeDIAssignID(J);
+ break;
+ case LLVMContext::MD_tbaa:
+ if (DoesKMove)
+ K->setMetadata(Kind, MDNode::getMostGenericTBAA(JMD, KMD));
+ break;
+ case LLVMContext::MD_alias_scope:
+ if (DoesKMove)
+ K->setMetadata(Kind, MDNode::getMostGenericAliasScope(JMD, KMD));
+ break;
+ case LLVMContext::MD_noalias:
+ case LLVMContext::MD_mem_parallel_loop_access:
+ if (DoesKMove)
+ K->setMetadata(Kind, MDNode::intersect(JMD, KMD));
+ break;
+ case LLVMContext::MD_access_group:
+ if (DoesKMove)
+ K->setMetadata(LLVMContext::MD_access_group,
+ intersectAccessGroups(K, J));
+ break;
+ case LLVMContext::MD_range:
+ if (!AAOnly && (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)))
+ K->setMetadata(Kind, MDNode::getMostGenericRange(JMD, KMD));
+ break;
+ case LLVMContext::MD_fpmath:
+ if (!AAOnly)
+ K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD));
+ break;
+ case LLVMContext::MD_invariant_load:
+ // If K moves, only set the !invariant.load if it is present in both
+ // instructions.
+ if (DoesKMove)
+ K->setMetadata(Kind, JMD);
+ break;
+ case LLVMContext::MD_nonnull:
+ if (!AAOnly && (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)))
+ K->setMetadata(Kind, JMD);
+ break;
+ case LLVMContext::MD_invariant_group:
+ // Preserve !invariant.group in K.
+ break;
+ case LLVMContext::MD_mmra:
+ // Combine MMRAs
+ break;
+ case LLVMContext::MD_align:
+ if (!AAOnly && (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)))
+ K->setMetadata(
+ Kind, MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
+ break;
+ case LLVMContext::MD_dereferenceable:
+ case LLVMContext::MD_dereferenceable_or_null:
+ if (!AAOnly && DoesKMove)
+ K->setMetadata(
+ Kind, MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
+ break;
+ case LLVMContext::MD_memprof:
+ if (!AAOnly)
+ K->setMetadata(Kind, MDNode::getMergedMemProfMetadata(KMD, JMD));
+ break;
+ case LLVMContext::MD_callsite:
+ if (!AAOnly)
+ K->setMetadata(Kind, MDNode::getMergedCallsiteMetadata(KMD, JMD));
+ break;
+ case LLVMContext::MD_preserve_access_index:
+ // Preserve !preserve.access.index in K.
+ break;
+ case LLVMContext::MD_noundef:
+ // If K does move, keep noundef if it is present in both instructions.
+ if (!AAOnly && DoesKMove)
+ K->setMetadata(Kind, JMD);
+ break;
+ case LLVMContext::MD_nontemporal:
+ // Preserve !nontemporal if it is present on both instructions.
+ if (!AAOnly)
+ K->setMetadata(Kind, JMD);
+ break;
+ case LLVMContext::MD_prof:
+ if (!AAOnly && DoesKMove)
+ K->setMetadata(Kind, MDNode::getMergedProfMetadata(KMD, JMD, K, J));
+ break;
+ case LLVMContext::MD_noalias_addrspace:
+ if (DoesKMove)
+ K->setMetadata(Kind, MDNode::getMostGenericNoaliasAddrspace(JMD, KMD));
+ break;
}
}
// Set !invariant.group from J if J has it. If both instructions have it
|
dianqk
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, LGTM. Thanks!
teresajohnson
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm but suggest a FIXME as noted below. Thanks!
| // Change to removing only explicitly listed other metadata, and assert | ||
| // on unknown metadata, to avoid inadvertently dropping newly added | ||
| // metadata types. | ||
| K->setMetadata(Kind, nullptr); // Remove unknown metadata |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it still makes sense to have a FIXME here to assert on unknown metadata. Otherwise we still end up in a situation where new metadata gets silently dropped, until someone digs into a case where it is missing, and that can be hard to find.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't assert on unknown metadata, because metadata is an extension mechanism. You can add , !foobar !{} on an instruction with some kind of custom meaning. This metadata has to be dropped in combineMetadata() to be conservatively correct, but we can't assert on encountering it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about doing this at least for the fixed MD types specified in llvm/include/llvm/IR/FixedMetadataKinds.def? If someone adds an MD type there, they should be flagged to update this handling.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's possible in principle. We'd need an extra value to indicate what the last fixed MD kind is and make it a named enum so that we can switch over the enum and get a -Werror on uncovered enum case. Just asserting it would be too much of a liability, as people wouldn't be aware they have to update this code.
I've left a TODO for now.
fhahn
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for cleaning this up
This moves combineAAMetadata() into Local and implements it via a new AAOnly flag, which will intersect only AA metadata and keep other known metadata. The existing KnownIDs list is dropped, because it is redundant with the switch in combineMetadata(), which already drops unknown metadata. I checked that the memcpy tests still pass if we adjust the logic for MD_memprof/MD_callsite to drop the metadata instead of arbitrarily picking one. Fixes llvm#121495.
bc23709 to
010d34d
Compare
This moves combineAAMetadata() into Local and implements it via a new AAOnly flag, which will intersect only AA metadata and keep other known metadata.
The existing KnownIDs list is dropped, because it is redundant with the switch in combineMetadata(), which already drops unknown metadata.
I tried a few variants of this, and ultimately went with the AAOnly flag because this way we make an explicit choice for each metadata kind supported by combineMetadata(), and ignoring the flag gives you conservatively correct behavior.
I checked that the memcpy tests still pass if we adjust the logic for MD_memprof/MD_callsite to drop the metadata instead of arbitrarily picking one.
Fixes #121495.