Skip to content

Commit 16afa09

Browse files
committed
[VFABI] Add support for vector functions that return struct types
This patch updates the `VFABIDemangler` to support vector functions that return struct types. For example, a vector variant of `sincos` that returns a vector of sine values and a vector of cosine values within a struct. This patch also adds some helpers in `llvm/IR/CallWideningUtils.h` for widening call return types. Some of these are used in the `VFABIDemangler`, and others will be used in subsequent patches, so this patch simply adds tests for them.
1 parent 3dbff90 commit 16afa09

File tree

9 files changed

+379
-20
lines changed

9 files changed

+379
-20
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/Analysis/LoopAccessAnalysis.h"
1919
#include "llvm/IR/Module.h"
2020
#include "llvm/IR/VFABIDemangler.h"
21+
#include "llvm/IR/VectorUtils.h"
2122
#include "llvm/Support/CheckedArithmetic.h"
2223

2324
namespace llvm {
@@ -127,19 +128,6 @@ namespace Intrinsic {
127128
typedef unsigned ID;
128129
}
129130

130-
/// A helper function for converting Scalar types to vector types. If
131-
/// the incoming type is void, we return void. If the EC represents a
132-
/// scalar, we return the scalar type.
133-
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
134-
if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
135-
return Scalar;
136-
return VectorType::get(Scalar, EC);
137-
}
138-
139-
inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
140-
return ToVectorTy(Scalar, ElementCount::getFixed(VF));
141-
}
142-
143131
/// Identify if the intrinsic is trivially vectorizable.
144132
/// This method returns true if the intrinsic's argument types are all scalars
145133
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===---- CallWideningUtils.h - Utils for widening scalar to vector calls --==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_IR_CALLWIDENINGUTILS_H
10+
#define LLVM_IR_CALLWIDENINGUTILS_H
11+
12+
#include "llvm/IR/DerivedTypes.h"
13+
14+
namespace llvm {
15+
16+
/// A helper for converting to wider (vector) types. For scalar types, this is
17+
/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
18+
/// struct where each element type has been widened to a vector type. Note: Only
19+
/// unpacked literal struct types are supported.
20+
Type *ToWideTy(Type *Ty, ElementCount EC);
21+
22+
/// A helper for converting wide types to narrow (non-vector) types. For vector
23+
/// types, this is equivalent to calling .getScalarType(). For struct types,
24+
/// this returns a new struct where each element type has been converted to a
25+
/// scalar type. Note: Only unpacked literal struct types are supported.
26+
Type *ToNarrowTy(Type *Ty);
27+
28+
/// Returns the types contained in `Ty`. For struct types, it returns the
29+
/// elements, all other types are returned directly.
30+
SmallVector<Type *, 2> getContainedTypes(Type *Ty);
31+
32+
/// Returns true if `Ty` is a vector type or a struct of vector types where all
33+
/// vector types share the same VF.
34+
bool isWideTy(Type *Ty);
35+
36+
/// Returns the vectorization factor for a widened type.
37+
inline ElementCount getWideTypeVF(Type *Ty) {
38+
assert(isWideTy(Ty) && "expected widened type");
39+
return cast<VectorType>(getContainedTypes(Ty).front())->getElementCount();
40+
}
41+
42+
} // namespace llvm
43+
44+
#endif

llvm/include/llvm/IR/VectorUtils.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===----------- VectorUtils.h - Vector type utility functions -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_IR_VECTORUTILS_H
10+
#define LLVM_IR_VECTORUTILS_H
11+
12+
#include "llvm/ADT/SmallVector.h"
13+
#include "llvm/IR/DerivedTypes.h"
14+
15+
namespace llvm {
16+
17+
/// A helper function for converting Scalar types to vector types. If
18+
/// the incoming type is void, we return void. If the EC represents a
19+
/// scalar, we return the scalar type.
20+
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
21+
if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
22+
return Scalar;
23+
return VectorType::get(Scalar, EC);
24+
}
25+
26+
inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
27+
return ToVectorTy(Scalar, ElementCount::getFixed(VF));
28+
}
29+
30+
} // namespace llvm
31+
32+
#endif

llvm/lib/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_llvm_component_library(LLVMCore
66
AutoUpgrade.cpp
77
BasicBlock.cpp
88
BuiltinGCs.cpp
9+
CallWideningUtils.cpp
910
Comdat.cpp
1011
ConstantFold.cpp
1112
ConstantFPRange.cpp

llvm/lib/IR/CallWideningUtils.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//===----------- VectorUtils.cpp - Vector type utility functions ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/IR/CallWideningUtils.h"
10+
#include "llvm/ADT/SmallVectorExtras.h"
11+
#include "llvm/IR/VectorUtils.h"
12+
13+
using namespace llvm;
14+
15+
/// A helper for converting to wider (vector) types. For scalar types, this is
16+
/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
17+
/// struct where each element type has been widened to a vector type. Note: Only
18+
/// unpacked literal struct types are supported.
19+
Type *llvm::ToWideTy(Type *Ty, ElementCount EC) {
20+
if (EC.isScalar())
21+
return Ty;
22+
auto *StructTy = dyn_cast<StructType>(Ty);
23+
if (!StructTy)
24+
return ToVectorTy(Ty, EC);
25+
assert(StructTy->isLiteral() && !StructTy->isPacked() &&
26+
"expected unpacked struct literal");
27+
return StructType::get(
28+
Ty->getContext(),
29+
map_to_vector(StructTy->elements(), [&](Type *ElTy) -> Type * {
30+
return VectorType::get(ElTy, EC);
31+
}));
32+
}
33+
34+
/// A helper for converting wide types to narrow (non-vector) types. For vector
35+
/// types, this is equivalent to calling .getScalarType(). For struct types,
36+
/// this returns a new struct where each element type has been converted to a
37+
/// scalar type. Note: Only unpacked literal struct types are supported.
38+
Type *llvm::ToNarrowTy(Type *Ty) {
39+
auto *StructTy = dyn_cast<StructType>(Ty);
40+
if (!StructTy)
41+
return Ty->getScalarType();
42+
assert(StructTy->isLiteral() && !StructTy->isPacked() &&
43+
"expected unpacked struct literal");
44+
return StructType::get(
45+
Ty->getContext(),
46+
map_to_vector(StructTy->elements(), [](Type *ElTy) -> Type * {
47+
return ElTy->getScalarType();
48+
}));
49+
}
50+
51+
/// Returns the types contained in `Ty`. For struct types, it returns the
52+
/// elements, all other types are returned directly.
53+
SmallVector<Type *, 2> llvm::getContainedTypes(Type *Ty) {
54+
auto *StructTy = dyn_cast<StructType>(Ty);
55+
if (StructTy)
56+
return to_vector<2>(StructTy->elements());
57+
return {Ty};
58+
}
59+
60+
/// Returns true if `Ty` is a vector type or a struct of vector types where all
61+
/// vector types share the same VF.
62+
bool llvm::isWideTy(Type *Ty) {
63+
auto *StructTy = dyn_cast<StructType>(Ty);
64+
if (StructTy && (!StructTy->isLiteral() || StructTy->isPacked()))
65+
return false;
66+
auto ContainedTys = getContainedTypes(Ty);
67+
if (ContainedTys.empty() || !ContainedTys.front()->isVectorTy())
68+
return false;
69+
ElementCount VF = cast<VectorType>(ContainedTys.front())->getElementCount();
70+
return all_of(ContainedTys, [&](Type *Ty) {
71+
return Ty->isVectorTy() && cast<VectorType>(Ty)->getElementCount() == VF;
72+
});
73+
}

llvm/lib/IR/VFABIDemangler.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "llvm/ADT/SetVector.h"
1111
#include "llvm/ADT/SmallString.h"
1212
#include "llvm/ADT/StringSwitch.h"
13+
#include "llvm/IR/CallWideningUtils.h"
1314
#include "llvm/IR/Module.h"
1415
#include "llvm/Support/Debug.h"
1516
#include "llvm/Support/raw_ostream.h"
@@ -346,12 +347,15 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
346347
// Also check the return type if not void.
347348
Type *RetTy = Signature->getReturnType();
348349
if (!RetTy->isVoidTy()) {
349-
std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
350-
// If we have an unknown scalar element type we can't find a reasonable VF.
351-
if (!ReturnEC)
352-
return std::nullopt;
353-
if (ElementCount::isKnownLT(*ReturnEC, MinEC))
354-
MinEC = *ReturnEC;
350+
for (Type *RetTy : getContainedTypes(RetTy)) {
351+
std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
352+
// If we have an unknown scalar element type we can't find a reasonable
353+
// VF.
354+
if (!ReturnEC)
355+
return std::nullopt;
356+
if (ElementCount::isKnownLT(*ReturnEC, MinEC))
357+
MinEC = *ReturnEC;
358+
}
355359
}
356360

357361
// The SVE Vector function call ABI bases the VF on the widest element types
@@ -566,7 +570,7 @@ FunctionType *VFABI::createFunctionType(const VFInfo &Info,
566570

567571
auto *RetTy = ScalarFTy->getReturnType();
568572
if (!RetTy->isVoidTy())
569-
RetTy = VectorType::get(RetTy, VF);
573+
RetTy = ToWideTy(RetTy, VF);
570574
return FunctionType::get(RetTy, VecTypes, false);
571575
}
572576

llvm/unittests/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_llvm_unittest(IRTests
1515
AttributesTest.cpp
1616
BasicBlockTest.cpp
1717
BasicBlockDbgInfoTest.cpp
18+
CallWideningUtilsTest.cpp
1819
CFGBuilder.cpp
1920
ConstantFPRangeTest.cpp
2021
ConstantRangeTest.cpp
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
//===------- CallWideningUtilsTest.cpp - Call widening utils tests --------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/IR/CallWideningUtils.h"
10+
#include "llvm/IR/DerivedTypes.h"
11+
#include "llvm/IR/LLVMContext.h"
12+
#include "gtest/gtest.h"
13+
14+
using namespace llvm;
15+
16+
namespace {
17+
18+
class CallWideningUtilsTest : public ::testing::Test {};
19+
20+
TEST(CallWideningUtilsTest, TestToWideTy) {
21+
LLVMContext C;
22+
23+
Type *ITy = Type::getInt32Ty(C);
24+
Type *FTy = Type::getFloatTy(C);
25+
Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
26+
Type *MixedStructTy = StructType::get(FTy, ITy);
27+
Type *VoidTy = Type::getVoidTy(C);
28+
29+
for (ElementCount VF :
30+
{ElementCount::getFixed(4), ElementCount::getScalable(2)}) {
31+
Type *IntVec = ToWideTy(ITy, VF);
32+
EXPECT_TRUE(isa<VectorType>(IntVec));
33+
EXPECT_EQ(IntVec, VectorType::get(ITy, VF));
34+
35+
Type *FloatVec = ToWideTy(FTy, VF);
36+
EXPECT_TRUE(isa<VectorType>(FloatVec));
37+
EXPECT_EQ(FloatVec, VectorType::get(FTy, VF));
38+
39+
Type *WideHomogeneousStructTy = ToWideTy(HomogeneousStructTy, VF);
40+
EXPECT_TRUE(isa<StructType>(WideHomogeneousStructTy));
41+
EXPECT_TRUE(
42+
cast<StructType>(WideHomogeneousStructTy)->containsHomogeneousTypes());
43+
EXPECT_TRUE(cast<StructType>(WideHomogeneousStructTy)->getNumElements() ==
44+
3);
45+
EXPECT_TRUE(cast<StructType>(WideHomogeneousStructTy)->getElementType(0) ==
46+
VectorType::get(FTy, VF));
47+
48+
Type *WideMixedStructTy = ToWideTy(MixedStructTy, VF);
49+
EXPECT_TRUE(isa<StructType>(WideMixedStructTy));
50+
EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getNumElements() == 2);
51+
EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getElementType(0) ==
52+
VectorType::get(FTy, VF));
53+
EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getElementType(1) ==
54+
VectorType::get(ITy, VF));
55+
56+
EXPECT_EQ(ToWideTy(VoidTy, VF), VoidTy);
57+
}
58+
59+
ElementCount ScalarVF = ElementCount::getFixed(1);
60+
for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy, VoidTy}) {
61+
EXPECT_EQ(ToWideTy(Ty, ScalarVF), Ty);
62+
}
63+
}
64+
65+
TEST(CallWideningUtilsTest, TestToNarrowTy) {
66+
LLVMContext C;
67+
68+
Type *ITy = Type::getInt32Ty(C);
69+
Type *FTy = Type::getFloatTy(C);
70+
Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
71+
Type *MixedStructTy = StructType::get(FTy, ITy);
72+
Type *VoidTy = Type::getVoidTy(C);
73+
74+
for (ElementCount VF : {ElementCount::getFixed(1), ElementCount::getFixed(4),
75+
ElementCount::getScalable(2)}) {
76+
for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy, VoidTy}) {
77+
// ToNarrowTy should be the inverse of ToWideTy.
78+
EXPECT_EQ(ToNarrowTy(ToWideTy(Ty, VF)), Ty);
79+
};
80+
}
81+
}
82+
83+
TEST(CallWideningUtilsTest, TestGetContainedTypes) {
84+
LLVMContext C;
85+
86+
Type *ITy = Type::getInt32Ty(C);
87+
Type *FTy = Type::getFloatTy(C);
88+
Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
89+
Type *MixedStructTy = StructType::get(FTy, ITy);
90+
Type *VoidTy = Type::getVoidTy(C);
91+
92+
EXPECT_EQ(getContainedTypes(ITy), SmallVector<Type *>({ITy}));
93+
EXPECT_EQ(getContainedTypes(FTy), SmallVector<Type *>({FTy}));
94+
EXPECT_EQ(getContainedTypes(VoidTy), SmallVector<Type *>({VoidTy}));
95+
EXPECT_EQ(getContainedTypes(HomogeneousStructTy),
96+
SmallVector<Type *>({FTy, FTy, FTy}));
97+
EXPECT_EQ(getContainedTypes(MixedStructTy), SmallVector<Type *>({FTy, ITy}));
98+
}
99+
100+
TEST(CallWideningUtilsTest, TestIsWideTy) {
101+
LLVMContext C;
102+
103+
Type *ITy = Type::getInt32Ty(C);
104+
Type *FTy = Type::getFloatTy(C);
105+
Type *NarrowStruct = StructType::get(FTy, ITy);
106+
Type *VoidTy = Type::getVoidTy(C);
107+
108+
EXPECT_FALSE(isWideTy(ITy));
109+
EXPECT_FALSE(isWideTy(NarrowStruct));
110+
EXPECT_FALSE(isWideTy(VoidTy));
111+
112+
ElementCount VF = ElementCount::getFixed(4);
113+
EXPECT_TRUE(isWideTy(ToWideTy(ITy, VF)));
114+
EXPECT_TRUE(isWideTy(ToWideTy(NarrowStruct, VF)));
115+
116+
Type *MixedVFStruct =
117+
StructType::get(VectorType::get(ITy, ElementCount::getFixed(2)),
118+
VectorType::get(ITy, ElementCount::getFixed(4)));
119+
EXPECT_FALSE(isWideTy(MixedVFStruct));
120+
121+
// Currently only literals types are considered wide.
122+
Type *NamedWideStruct = StructType::create("Named", VectorType::get(ITy, VF),
123+
VectorType::get(ITy, VF));
124+
EXPECT_FALSE(isWideTy(NamedWideStruct));
125+
126+
// Currently only unpacked types are considered wide.
127+
Type *PackedWideStruct = StructType::get(
128+
C, ArrayRef<Type *>{VectorType::get(ITy, VF), VectorType::get(ITy, VF)},
129+
/*isPacked=*/true);
130+
EXPECT_FALSE(isWideTy(PackedWideStruct));
131+
}
132+
133+
TEST(CallWideningUtilsTest, TestGetWideTypeVF) {
134+
LLVMContext C;
135+
136+
Type *ITy = Type::getInt32Ty(C);
137+
Type *FTy = Type::getFloatTy(C);
138+
Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
139+
Type *MixedStructTy = StructType::get(FTy, ITy);
140+
141+
for (ElementCount VF :
142+
{ElementCount::getFixed(4), ElementCount::getScalable(2)}) {
143+
for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy}) {
144+
EXPECT_EQ(getWideTypeVF(ToWideTy(Ty, VF)), VF);
145+
};
146+
}
147+
}
148+
149+
} // namespace

0 commit comments

Comments
 (0)