Skip to content

Commit a846f29

Browse files
authored
[CIR][CUDA] CallConvLowering for basic types in NVPTX (#1468)
There are some subtleties here. This is the code in OG: ```cpp // note: this is different from default ABI if (!RetTy->isScalarType()) return ABIArgInfo::getDirect(); ``` which says we should return structs directly. It's correct, has have the same behaviour as `nvcc`, and it obeys the PTX ABI as well. The comment dates back to 2013 (see [this commit](llvm/llvm-project@f9329ff) -- it didn't provide any explanation either), so I believe it's outdated. I didn't include this comment in the PR.
1 parent 7e4213f commit a846f29

File tree

4 files changed

+145
-3
lines changed

4 files changed

+145
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIRTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ class StructType
186186
};
187187

188188
bool isAnyFloatingPointType(mlir::Type t);
189+
bool isScalarType(mlir::Type t);
189190
bool isFPOrFPVectorTy(mlir::Type);
190191
bool isIntOrIntVectorTy(mlir::Type);
191192
} // namespace cir

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,12 @@ bool cir::isAnyFloatingPointType(mlir::Type t) {
841841
cir::FP80Type>(t);
842842
}
843843

844+
bool cir::isScalarType(mlir::Type ty) {
845+
return isa<cir::IntType, cir::BoolType, cir::SingleType, cir::DoubleType,
846+
cir::LongDoubleType, cir::FP16Type, cir::FP128Type, cir::FP80Type,
847+
cir::DataMemberType, cir::PointerType>(ty);
848+
}
849+
844850
//===----------------------------------------------------------------------===//
845851
// Floating-point and Float-point Vector type helpers
846852
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/NVPTX.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
#include "TargetInfo.h"
1313
#include "TargetLoweringInfo.h"
1414
#include "clang/CIR/ABIArgInfo.h"
15+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
1516
#include "clang/CIR/MissingFeatures.h"
17+
#include "llvm/Support/Casting.h"
1618
#include "llvm/Support/ErrorHandling.h"
1719

1820
using ABIArgInfo = cir::ABIArgInfo;
@@ -31,9 +33,10 @@ class NVPTXABIInfo : public ABIInfo {
3133
NVPTXABIInfo(LowerTypes &lt) : ABIInfo(lt) {}
3234

3335
private:
34-
void computeInfo(LowerFunctionInfo &fi) const override {
35-
llvm_unreachable("NYI");
36-
}
36+
ABIArgInfo classifyReturnType(mlir::Type ty) const;
37+
ABIArgInfo classifyArgumentType(mlir::Type ty) const;
38+
39+
void computeInfo(LowerFunctionInfo &fi) const override;
3740
};
3841

3942
class NVPTXTargetLoweringInfo : public TargetLoweringInfo {
@@ -63,6 +66,48 @@ class NVPTXTargetLoweringInfo : public TargetLoweringInfo {
6366

6467
} // namespace
6568

69+
ABIArgInfo NVPTXABIInfo::classifyReturnType(mlir::Type ty) const {
70+
if (llvm::isa<VoidType>(ty))
71+
return ABIArgInfo::getIgnore();
72+
73+
if (getContext().getLangOpts().OpenMP)
74+
llvm_unreachable("NYI");
75+
76+
if (!isScalarType(ty))
77+
return ABIArgInfo::getDirect();
78+
79+
// OG treats enums as their underlying type.
80+
// This has already been done for CIR.
81+
82+
// Integers with size < 32 must be extended to 32 bits.
83+
// (See Section 3.3 of PTX ABI.)
84+
return (isPromotableIntegerTypeForABI(ty) ? ABIArgInfo::getExtend(ty)
85+
: ABIArgInfo::getDirect());
86+
}
87+
88+
ABIArgInfo NVPTXABIInfo::classifyArgumentType(mlir::Type ty) const {
89+
if (isAggregateTypeForABI(ty))
90+
llvm_unreachable("NYI");
91+
92+
if (auto intType = llvm::dyn_cast<IntType>(ty)) {
93+
if (intType.getWidth() > 128)
94+
llvm_unreachable("NYI");
95+
}
96+
97+
return (isPromotableIntegerTypeForABI(ty) ? ABIArgInfo::getExtend(ty)
98+
: ABIArgInfo::getDirect());
99+
}
100+
101+
void NVPTXABIInfo::computeInfo(LowerFunctionInfo &fi) const {
102+
if (!getCXXABI().classifyReturnType(fi))
103+
fi.getReturnInfo() = classifyReturnType(fi.getReturnType());
104+
105+
for (auto &&[count, argument] : llvm::enumerate(fi.arguments()))
106+
argument.info = count < fi.getNumRequiredArgs()
107+
? classifyArgumentType(argument.type)
108+
: ABIArgInfo::getDirect();
109+
}
110+
66111
std::unique_ptr<TargetLoweringInfo>
67112
createNVPTXTargetLoweringInfo(LowerModule &lowerModule) {
68113
return std::make_unique<NVPTXTargetLoweringInfo>(lowerModule.getTypes());
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// RUN: %clang_cc1 -std=c++20 -triple nvptx-nvidia-cuda -fclangir \
2+
// RUN: -fclangir-call-conv-lowering -emit-cir-flat -mmlir \
3+
// RUN: --mlir-print-ir-after=cir-call-conv-lowering %s -o %t.cir
4+
// RUN: FileCheck --input-file=%t.cir %s
5+
6+
// Test call conv lowering for trivial cases. //
7+
8+
// CHECK: @_Z4Voidv()
9+
void Void(void) {
10+
// CHECK: cir.call @_Z4Voidv() : () -> ()
11+
Void();
12+
}
13+
14+
// CHECK: @_Z4Boolb(%arg0: !cir.bool {cir.zeroext} loc({{.+}})) -> (!cir.bool {cir.zeroext})
15+
bool Bool(bool a) {
16+
// CHECK: cir.call @_Z4Boolb({{.+}}) : (!cir.bool) -> !cir.bool
17+
return Bool(a);
18+
}
19+
20+
// CHECK: cir.func @_Z5UCharh(%arg0: !u8i {cir.zeroext} loc({{.+}})) -> (!u8i {cir.zeroext})
21+
unsigned char UChar(unsigned char c) {
22+
// CHECK: cir.call @_Z5UCharh(%{{.+}}) : (!u8i) -> !u8i
23+
return UChar(c);
24+
}
25+
26+
// CHECK: cir.func @_Z6UShortt(%arg0: !u16i {cir.zeroext} loc({{.+}})) -> (!u16i {cir.zeroext})
27+
unsigned short UShort(unsigned short s) {
28+
// CHECK: cir.call @_Z6UShortt(%{{.+}}) : (!u16i) -> !u16i
29+
return UShort(s);
30+
}
31+
32+
// CHECK: cir.func @_Z4UIntj(%arg0: !u32i loc({{.+}})) -> !u32i
33+
unsigned int UInt(unsigned int i) {
34+
// CHECK: cir.call @_Z4UIntj(%{{.+}}) : (!u32i) -> !u32i
35+
return UInt(i);
36+
}
37+
38+
// CHECK: cir.func @_Z5ULongm(%arg0: !u32i loc({{.+}})) -> !u32i
39+
unsigned long ULong(unsigned long l) {
40+
// CHECK: cir.call @_Z5ULongm(%{{.+}}) : (!u32i) -> !u32i
41+
return ULong(l);
42+
}
43+
44+
// CHECK: cir.func @_Z9ULongLongy(%arg0: !u64i loc({{.+}})) -> !u64i
45+
unsigned long long ULongLong(unsigned long long l) {
46+
// CHECK: cir.call @_Z9ULongLongy(%{{.+}}) : (!u64i) -> !u64i
47+
return ULongLong(l);
48+
}
49+
50+
// CHECK: cir.func @_Z4Chara(%arg0: !s8i {cir.signext} loc({{.+}})) -> (!s8i {cir.signext})
51+
char Char(signed char c) {
52+
// CHECK: cir.call @_Z4Chara(%{{.+}}) : (!s8i) -> !s8i
53+
return Char(c);
54+
}
55+
56+
// CHECK: cir.func @_Z5Shorts(%arg0: !s16i {cir.signext} loc({{.+}})) -> (!s16i {cir.signext})
57+
short Short(short s) {
58+
// CHECK: cir.call @_Z5Shorts(%{{.+}}) : (!s16i) -> !s16i
59+
return Short(s);
60+
}
61+
62+
// CHECK: cir.func @_Z3Inti(%arg0: !s32i loc({{.+}})) -> !s32i
63+
int Int(int i) {
64+
// CHECK: cir.call @_Z3Inti(%{{.+}}) : (!s32i) -> !s32i
65+
return Int(i);
66+
}
67+
68+
// CHECK: cir.func @_Z4Longl(%arg0: !s32i loc({{.+}})) -> !s32i
69+
long Long(long l) {
70+
// CHECK: cir.call @_Z4Longl(%{{.+}}) : (!s32i) -> !s32i
71+
return Long(l);
72+
}
73+
74+
// CHECK: cir.func @_Z8LongLongx(%arg0: !s64i loc({{.+}})) -> !s64i
75+
long long LongLong(long long l) {
76+
// CHECK: cir.call @_Z8LongLongx(%{{.+}}) : (!s64i) -> !s64i
77+
return LongLong(l);
78+
}
79+
80+
81+
// Check for structs.
82+
83+
struct Struct {
84+
int a, b, c, d, e;
85+
};
86+
87+
// CHECK: cir.func @_Z10StructFuncv() -> !ty_Struct
88+
Struct StructFunc() {
89+
return { 0, 1, 2, 3, 4 };
90+
}

0 commit comments

Comments
 (0)