Skip to content

Commit 445dd85

Browse files
convert HIP struct type vector to llvm vector type (#416) (#490)
2 parents cf5a4bf + 87ad04d commit 445dd85

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
#include "llvm/Transforms/Scalar.h"
8484
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
8585
#include "llvm/Transforms/Utils/Local.h"
86+
#include "llvm/TargetParser/Triple.h"
8687
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
8788
#include "llvm/Transforms/Utils/SSAUpdater.h"
8889
#include <algorithm>
@@ -5242,6 +5243,34 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
52425243
// FIXME: We might want to defer PHI speculation until after here.
52435244
// FIXME: return nullptr;
52445245
} else {
5246+
// AMDGPU: If the target is AMDGPU and the chosen SliceTy is a HIP vector
5247+
// struct of 2 or 4 identically typed elements, canonicalize it to an IR vector.
5248+
// This helps SROA treat it as a single value and unlock vector loads/stores.
5249+
// We pattern-match struct names starting with "struct.HIP_vector".
5250+
if (Function *F = AI.getFunction()) {
5251+
Triple TT(F->getParent()->getTargetTriple());
5252+
if (TT.isAMDGPU()) {
5253+
if (auto *STy = dyn_cast<StructType>(SliceTy)) {
5254+
StringRef Name = STy->hasName() ? STy->getName() : StringRef();
5255+
if (Name.starts_with("struct.HIP_vector")) {
5256+
unsigned NumElts = STy->getNumElements();
5257+
if ((NumElts == 2 || NumElts == 4) && NumElts > 0) {
5258+
Type *EltTy = STy->getElementType(0);
5259+
bool AllSame = true;
5260+
for (unsigned I = 1; I < NumElts; ++I)
5261+
if (STy->getElementType(I) != EltTy) {
5262+
AllSame = false;
5263+
break;
5264+
}
5265+
if (AllSame && VectorType::isValidElementType(EltTy)) {
5266+
SliceTy = FixedVectorType::get(EltTy, NumElts);
5267+
}
5268+
}
5269+
}
5270+
}
5271+
}
5272+
}
5273+
52455274
// Make sure the alignment is compatible with P.beginOffset().
52465275
const Align Alignment = commonAlignment(AI.getAlign(), P.beginOffset());
52475276
// If we will get at least this much alignment from the type alone, leave

0 commit comments

Comments
 (0)