Skip to content

Commit 07e9d6e

Browse files
AdUhTkJmlanza
authored and committed
[CIR][CUDA] Support device-side printf (#1475)
The choice of adding a separate file imitates that of OG (the original Clang CodeGen, i.e. clang/lib/CodeGen).
1 parent 28f9896 commit 07e9d6e

File tree

4 files changed

+148
-3
lines changed

4 files changed

+148
-3
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2340,12 +2340,14 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
23402340
llvm_unreachable("BI__builtin_load_halff NYI");
23412341

23422342
case Builtin::BI__builtin_printf:
2343-
llvm_unreachable("BI__builtin_printf NYI");
23442343
case Builtin::BIprintf:
2345-
if (getTarget().getTriple().isNVPTX() ||
2346-
getTarget().getTriple().isAMDGCN()) {
2344+
assert(E->getNumArgs() >= 1);
2345+
if (getTarget().getTriple().isAMDGCN()) {
23472346
llvm_unreachable("BIprintf NYI");
23482347
}
2348+
if (getTarget().getTriple().isNVPTX()) {
2349+
return RValue::get(emitNVPTXDevicePrintfCallExpr(E));
2350+
}
23492351
break;
23502352

23512353
case Builtin::BI__builtin_canonicalize:

clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,96 @@ mlir::Value CIRGenFunction::emitNVPTXBuiltinExpr(unsigned builtinId,
8080
llvm_unreachable("NYI");
8181
}
8282
}
83+
84+
// vprintf takes two args: A format string, and a pointer to a buffer containing
85+
// the varargs.
86+
//
87+
// For example, the call
88+
//
89+
// printf("format string", arg1, arg2, arg3);
90+
//
91+
// is converted into something resembling
92+
//
93+
// struct Tmp {
94+
// Arg1 a1;
95+
// Arg2 a2;
96+
// Arg3 a3;
97+
// };
98+
// char* buf = alloca(sizeof(Tmp));
99+
// *(Tmp*)buf = {a1, a2, a3};
100+
// vprintf("format string", buf);
101+
//
102+
// `buf` is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of
103+
// the args is itself aligned to its preferred alignment.
104+
//
105+
// Note that by the time this function runs, the arguments have already
106+
// undergone the standard C vararg promotion (short -> int, float -> double
107+
// etc). In this function we pack the arguments into the buffer described above.
108+
mlir::Value packArgsIntoNVPTXFormatBuffer(CIRGenFunction &cgf,
109+
const CallArgList &args,
110+
mlir::Location loc) {
111+
const CIRDataLayout &dataLayout = cgf.CGM.getDataLayout();
112+
CIRGenBuilderTy &builder = cgf.getBuilder();
113+
114+
if (args.size() <= 1)
115+
// If there are no arguments other than the format string,
116+
// pass a nullptr to vprintf.
117+
return builder.getNullPtr(cgf.VoidPtrTy, loc);
118+
119+
llvm::SmallVector<mlir::Type, 8> argTypes;
120+
for (auto arg : llvm::drop_begin(args))
121+
argTypes.push_back(arg.getRValue(cgf, loc).getScalarVal().getType());
122+
123+
// We can directly store the arguments into a struct, and the alignment
124+
// would automatically be correct. That's because vprintf does not
125+
// accept aggregates.
126+
mlir::Type allocaTy =
127+
cir::StructType::get(&cgf.getMLIRContext(), argTypes, /*packed=*/false,
128+
/*padded=*/false, StructType::Struct);
129+
mlir::Value alloca =
130+
cgf.CreateTempAlloca(allocaTy, loc, "printf_args", nullptr);
131+
132+
for (auto [i, arg] : llvm::enumerate(llvm::drop_begin(args))) {
133+
mlir::Value member =
134+
builder.createGetMember(loc, cir::PointerType::get(argTypes[i]), alloca,
135+
/*name=*/"", /*index=*/i);
136+
auto preferredAlign = clang::CharUnits::fromQuantity(
137+
dataLayout.getPrefTypeAlign(argTypes[i]).value());
138+
builder.createAlignedStore(loc, arg.getRValue(cgf, loc).getScalarVal(),
139+
member, preferredAlign);
140+
}
141+
142+
return builder.createBitcast(alloca, cgf.VoidPtrTy);
143+
}
144+
145+
/// Lower a device-side `printf` call for an NVPTX target into a call to the
/// runtime function `vprintf(format, packedArgs)`.
///
/// The call arguments are emitted first. If any argument after the format
/// string is not a scalar, an "unsupported" error is reported on \p expr and
/// a constant 0 of type SInt32Ty is returned instead of a call. Otherwise the
/// variadic arguments are packed into a buffer and the result of the
/// `vprintf` call is returned.
mlir::Value
CIRGenFunction::emitNVPTXDevicePrintfCallExpr(const CallExpr *expr) {
  assert(CGM.getTriple().isNVPTX());

  const auto *callee = expr->getDirectCallee();
  CallArgList args;
  emitCallArgs(args, callee->getType()->getAs<FunctionProtoType>(),
               expr->arguments(), callee);

  mlir::Location loc = getLoc(expr->getBeginLoc());

  // Device-side printf accepts only scalars after the format string.
  auto isNonScalar = [&](const CallArg &arg) {
    return !arg.getRValue(*this, loc).isScalar();
  };
  if (llvm::any_of(llvm::drop_begin(args), isNonScalar)) {
    CGM.ErrorUnsupported(expr, "non-scalar args to printf");
    return builder.getConstInt(loc, SInt32Ty, 0);
  }

  mlir::Value argBuffer = packArgsIntoNVPTXFormatBuffer(*this, args, loc);

  // Runtime signature: int vprintf(char *format, void *packedData);
  auto vprintfTy =
      FuncType::get({cir::PointerType::get(SInt8Ty), VoidPtrTy}, SInt32Ty);
  auto vprintfFn = CGM.createRuntimeFunction(vprintfTy, "vprintf");

  mlir::Value format = args[0].getRValue(*this, loc).getScalarVal();
  auto call = builder.createCallOp(loc, vprintfFn, {format, argBuffer});
  return call.getResult();
}

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,8 @@ class CIRGenFunction : public CIRGenTypeCache {
14781478
mlir::Value emitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
14791479
mlir::Value emitNVPTXBuiltinExpr(unsigned builtinID, const CallExpr *expr);
14801480

1481+
mlir::Value emitNVPTXDevicePrintfCallExpr(const CallExpr *expr);
1482+
14811483
/// Given an expression with a pointer type, emit the value and compute our
14821484
/// best estimate of the alignment of the pointee.
14831485
///

clang/test/CIR/CodeGen/CUDA/printf.cu

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include "../Inputs/cuda.h"

// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
// RUN:            -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
// RUN:            %s -o %t.cir
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s

// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
// RUN:            -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
// RUN:            %s -o %t.ll
// RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.ll %s

// Checks the lowering of device-side printf to vprintf on NVPTX, for both the
// CIR output and the LLVM IR output.

// One variadic argument: it is stored into an anonymous struct whose address
// is handed to vprintf as the second operand.
__device__ void printer() {
  printf("%d", 0);
}

// CIR-DEVICE: cir.func @_Z7printerv() extra({{.*}}) {
// CIR-DEVICE:   %[[#Packed:]] = cir.alloca !ty_anon_struct
// CIR-DEVICE:   %[[#Zero:]] = cir.const #cir.int<0> : !s32i
// CIR-DEVICE:   %[[#Field0:]] = cir.get_member %[[#Packed]][0]
// CIR-DEVICE:   cir.store align(4) %[[#Zero]], %[[#Field0]]
// CIR-DEVICE:   %[[#Output:]] = cir.cast(bitcast, %[[#Packed]] : !cir.ptr<!ty_anon_struct>)
// CIR-DEVICE:   cir.call @vprintf(%{{.+}}, %[[#Output]])
// CIR-DEVICE:   cir.return
// CIR-DEVICE: }

// LLVM-DEVICE: define dso_local void @_Z7printerv() {{.*}} {
// LLVM-DEVICE:   %[[#LLVMPacked:]] = alloca { i32 }, i64 1, align 8
// LLVM-DEVICE:   %[[#LLVMField0:]] = getelementptr { i32 }, ptr %[[#LLVMPacked]], i32 0, i32 0
// LLVM-DEVICE:   store i32 0, ptr %[[#LLVMField0]], align 4
// LLVM-DEVICE:   call i32 @vprintf(ptr @.str, ptr %[[#LLVMPacked]])
// LLVM-DEVICE:   ret void
// LLVM-DEVICE: }

// No variadic arguments: vprintf receives a null pointer as its second
// operand.
__device__ void no_extra() {
  printf("hello world");
}

// CIR-DEVICE: cir.func @_Z8no_extrav() extra({{.*}}) {
// CIR-DEVICE:   %[[#NULLPTR:]] = cir.const #cir.ptr<null>
// CIR-DEVICE:   cir.call @vprintf(%{{.+}}, %[[#NULLPTR]])
// CIR-DEVICE:   cir.return
// CIR-DEVICE: }

// LLVM-DEVICE: define dso_local void @_Z8no_extrav() {{.*}} {
// LLVM-DEVICE:   call i32 @vprintf(ptr @.str.1, ptr null)
// LLVM-DEVICE:   ret void
// LLVM-DEVICE: }

0 commit comments

Comments
 (0)