Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 59 additions & 83 deletions clang/lib/CodeGen/Targets/RISCV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,98 +441,74 @@ bool RISCVABIInfo::detectVLSCCEligibleStruct(QualType Ty, unsigned ABIVLen,
// __attribute__((vector_size(64))) int d;
// }
//
// Struct of 1 fixed-length vector is passed as a scalable vector.
// Struct of >1 fixed-length vectors are passed as vector tuple.
// Struct of 1 array of fixed-length vectors is passed as a scalable vector.
// Otherwise, pass the struct indirectly.

if (llvm::StructType *STy = dyn_cast<llvm::StructType>(CGT.ConvertType(Ty))) {
unsigned NumElts = STy->getStructNumElements();
if (NumElts > 8)
return false;
// 1. Struct of 1 fixed-length vector is passed as a scalable vector.
// 2. Struct of >1 fixed-length vectors are passed as vector tuple.
// 3. Struct of an array with 1 element of fixed-length vectors is passed as a
// scalable vector.
// 4. Struct of an array with >1 elements of fixed-length vectors is passed as
// vector tuple.
// 5. Otherwise, pass the struct indirectly.

llvm::StructType *STy = dyn_cast<llvm::StructType>(CGT.ConvertType(Ty));
if (!STy)
return false;

auto *FirstEltTy = STy->getElementType(0);
if (!STy->containsHomogeneousTypes())
return false;
unsigned NumElts = STy->getStructNumElements();
if (NumElts > 8)
return false;

// Check structure of fixed-length vectors and turn them into vector tuple
// type if legal.
if (auto *FixedVecTy = dyn_cast<llvm::FixedVectorType>(FirstEltTy)) {
if (NumElts == 1) {
// Handle single fixed-length vector.
VLSType = llvm::ScalableVectorType::get(
FixedVecTy->getElementType(),
llvm::divideCeil(FixedVecTy->getNumElements() *
llvm::RISCV::RVVBitsPerBlock,
ABIVLen));
// Check registers needed <= 8.
return llvm::divideCeil(
FixedVecTy->getNumElements() *
FixedVecTy->getElementType()->getScalarSizeInBits(),
ABIVLen) <= 8;
}
// LMUL
// = fixed-length vector size / ABIVLen
// = 8 * I8EltCount / RVVBitsPerBlock
// =>
// I8EltCount
// = (fixed-length vector size * RVVBitsPerBlock) / (ABIVLen * 8)
unsigned I8EltCount = llvm::divideCeil(
FixedVecTy->getNumElements() *
FixedVecTy->getElementType()->getScalarSizeInBits() *
llvm::RISCV::RVVBitsPerBlock,
ABIVLen * 8);
VLSType = llvm::TargetExtType::get(
getVMContext(), "riscv.vector.tuple",
llvm::ScalableVectorType::get(llvm::Type::getInt8Ty(getVMContext()),
I8EltCount),
NumElts);
// Check registers needed <= 8.
return NumElts *
llvm::divideCeil(
FixedVecTy->getNumElements() *
FixedVecTy->getElementType()->getScalarSizeInBits(),
ABIVLen) <=
8;
}
auto *FirstEltTy = STy->getElementType(0);
if (!STy->containsHomogeneousTypes())
return false;

// If elements are not fixed-length vectors, it should be an array.
if (auto *ArrayTy = dyn_cast<llvm::ArrayType>(FirstEltTy)) {
// Only struct of single array is accepted
if (NumElts != 1)
return false;
FirstEltTy = ArrayTy->getArrayElementType();
NumElts = ArrayTy->getNumElements();
}

// Check array of fixed-length vector and turn it into scalable vector type
// if legal.
if (auto *ArrTy = dyn_cast<llvm::ArrayType>(FirstEltTy)) {
unsigned NumArrElt = ArrTy->getNumElements();
if (NumArrElt > 8)
return false;
auto *FixedVecTy = dyn_cast<llvm::FixedVectorType>(FirstEltTy);
if (!FixedVecTy)
return false;

auto *ArrEltTy = dyn_cast<llvm::FixedVectorType>(ArrTy->getElementType());
if (!ArrEltTy)
return false;
// Check registers needed <= 8.
if (NumElts * llvm::divideCeil(
FixedVecTy->getNumElements() *
FixedVecTy->getElementType()->getScalarSizeInBits(),
ABIVLen) >
8)
return false;

// LMUL
// = NumArrElt * fixed-length vector size / ABIVLen
// = fixed-length vector elt size * ScalVecNumElts / RVVBitsPerBlock
// =>
// ScalVecNumElts
// = (NumArrElt * fixed-length vector size * RVVBitsPerBlock) /
// (ABIVLen * fixed-length vector elt size)
// = NumArrElt * num fixed-length vector elt * RVVBitsPerBlock /
// ABIVLen
unsigned ScalVecNumElts = llvm::divideCeil(
NumArrElt * ArrEltTy->getNumElements() * llvm::RISCV::RVVBitsPerBlock,
ABIVLen);
VLSType = llvm::ScalableVectorType::get(ArrEltTy->getElementType(),
ScalVecNumElts);
// Check registers needed <= 8.
return llvm::divideCeil(
ScalVecNumElts *
ArrEltTy->getElementType()->getScalarSizeInBits(),
llvm::RISCV::RVVBitsPerBlock) <= 8;
}
// Turn them into scalable vector type or vector tuple type if legal.
if (NumElts == 1) {
// Handle single fixed-length vector.
VLSType = llvm::ScalableVectorType::get(
FixedVecTy->getElementType(),
llvm::divideCeil(FixedVecTy->getNumElements() *
llvm::RISCV::RVVBitsPerBlock,
ABIVLen));
return true;
}
return false;

// LMUL
// = fixed-length vector size / ABIVLen
// = 8 * I8EltCount / RVVBitsPerBlock
// =>
// I8EltCount
// = (fixed-length vector size * RVVBitsPerBlock) / (ABIVLen * 8)
unsigned I8EltCount =
llvm::divideCeil(FixedVecTy->getNumElements() *
FixedVecTy->getElementType()->getScalarSizeInBits() *
llvm::RISCV::RVVBitsPerBlock,
ABIVLen * 8);
VLSType = llvm::TargetExtType::get(
getVMContext(), "riscv.vector.tuple",
llvm::ScalableVectorType::get(llvm::Type::getInt8Ty(getVMContext()),
I8EltCount),
NumElts);
return true;
}

// Fixed-length RVV vectors are represented as scalable vectors in function
Expand Down
8 changes: 4 additions & 4 deletions clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,14 @@ void __attribute__((riscv_vls_cc)) test_st_i32x4_arr1(struct st_i32x4_arr1 arg)
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4_arr1_256(<vscale x 1 x i32> %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr1_256(struct st_i32x4_arr1 arg) {}

// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_i32x4_arr4(<vscale x 8 x i32> %arg)
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_i32x4_arr4(target("riscv.vector.tuple", <vscale x 8 x i8>, 4) %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x4_arr4(struct st_i32x4_arr4 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4_arr4_256(<vscale x 4 x i32> %arg)
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4_arr4_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 4) %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr4_256(struct st_i32x4_arr4 arg) {}

// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_i32x4_arr8(<vscale x 16 x i32> %arg)
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_i32x4_arr8(target("riscv.vector.tuple", <vscale x 8 x i8>, 8) %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x4_arr8(struct st_i32x4_arr8 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4_arr8_256(<vscale x 8 x i32> %arg)
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4_arr8_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 8) %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr8_256(struct st_i32x4_arr8 arg) {}

// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_i32x4x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg)
Expand Down
8 changes: 4 additions & 4 deletions clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,14 @@ typedef int __attribute__((vector_size(256))) int32x64_t;
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z22test_st_i32x4_arr1_25613st_i32x4_arr1(<vscale x 1 x i32> %arg)
[[riscv::vls_cc(256)]] void test_st_i32x4_arr1_256(struct st_i32x4_arr1 arg) {}

// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z18test_st_i32x4_arr413st_i32x4_arr4(<vscale x 8 x i32> %arg)
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z18test_st_i32x4_arr413st_i32x4_arr4(target("riscv.vector.tuple", <vscale x 8 x i8>, 4) %arg)
[[riscv::vls_cc]] void test_st_i32x4_arr4(struct st_i32x4_arr4 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z22test_st_i32x4_arr4_25613st_i32x4_arr4(<vscale x 4 x i32> %arg)
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z22test_st_i32x4_arr4_25613st_i32x4_arr4(target("riscv.vector.tuple", <vscale x 4 x i8>, 4) %arg)
[[riscv::vls_cc(256)]] void test_st_i32x4_arr4_256(struct st_i32x4_arr4 arg) {}

// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z18test_st_i32x4_arr813st_i32x4_arr8(<vscale x 16 x i32> %arg)
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z18test_st_i32x4_arr813st_i32x4_arr8(target("riscv.vector.tuple", <vscale x 8 x i8>, 8) %arg)
[[riscv::vls_cc]] void test_st_i32x4_arr8(struct st_i32x4_arr8 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z22test_st_i32x4_arr8_25613st_i32x4_arr8(<vscale x 8 x i32> %arg)
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z22test_st_i32x4_arr8_25613st_i32x4_arr8(target("riscv.vector.tuple", <vscale x 4 x i8>, 8) %arg)
[[riscv::vls_cc(256)]] void test_st_i32x4_arr8_256(struct st_i32x4_arr8 arg) {}

// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z15test_st_i32x4x210st_i32x4x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg)
Expand Down
Loading