@@ -80,3 +80,96 @@ mlir::Value CIRGenFunction::emitNVPTXBuiltinExpr(unsigned builtinId,
80
80
llvm_unreachable (" NYI" );
81
81
}
82
82
}
83
+
84
+ // vprintf takes two args: A format string, and a pointer to a buffer containing
85
+ // the varargs.
86
+ //
87
+ // For example, the call
88
+ //
89
+ // printf("format string", arg1, arg2, arg3);
90
+ //
91
+ // is converted into something resembling
92
+ //
93
+ // struct Tmp {
94
+ // Arg1 a1;
95
+ // Arg2 a2;
96
+ // Arg3 a3;
97
+ // };
98
+ // char* buf = alloca(sizeof(Tmp));
99
+ // *(Tmp*)buf = {a1, a2, a3};
100
+ // vprintf("format string", buf);
101
+ //
102
+ // `buf` is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of
103
+ // the args is itself aligned to its preferred alignment.
104
+ //
105
+ // Note that by the time this function runs, the arguments have already
106
+ // undergone the standard C vararg promotion (short -> int, float -> double
107
+ // etc). In this function we pack the arguments into the buffer described above.
108
+ mlir::Value packArgsIntoNVPTXFormatBuffer (CIRGenFunction &cgf,
109
+ const CallArgList &args,
110
+ mlir::Location loc) {
111
+ const CIRDataLayout &dataLayout = cgf.CGM .getDataLayout ();
112
+ CIRGenBuilderTy &builder = cgf.getBuilder ();
113
+
114
+ if (args.size () <= 1 )
115
+ // If there are no arguments other than the format string,
116
+ // pass a nullptr to vprintf.
117
+ return builder.getNullPtr (cgf.VoidPtrTy , loc);
118
+
119
+ llvm::SmallVector<mlir::Type, 8 > argTypes;
120
+ for (auto arg : llvm::drop_begin (args))
121
+ argTypes.push_back (arg.getRValue (cgf, loc).getScalarVal ().getType ());
122
+
123
+ // We can directly store the arguments into a struct, and the alignment
124
+ // would automatically be correct. That's because vprintf does not
125
+ // accept aggregates.
126
+ mlir::Type allocaTy =
127
+ cir::StructType::get (&cgf.getMLIRContext (), argTypes, /* packed=*/ false ,
128
+ /* padded=*/ false , StructType::Struct);
129
+ mlir::Value alloca =
130
+ cgf.CreateTempAlloca (allocaTy, loc, " printf_args" , nullptr );
131
+
132
+ for (auto [i, arg] : llvm::enumerate (llvm::drop_begin (args))) {
133
+ mlir::Value member =
134
+ builder.createGetMember (loc, cir::PointerType::get (argTypes[i]), alloca,
135
+ /* name=*/ " " , /* index=*/ i);
136
+ auto preferredAlign = clang::CharUnits::fromQuantity (
137
+ dataLayout.getPrefTypeAlign (argTypes[i]).value ());
138
+ builder.createAlignedStore (loc, arg.getRValue (cgf, loc).getScalarVal (),
139
+ member, preferredAlign);
140
+ }
141
+
142
+ return builder.createBitcast (alloca, cgf.VoidPtrTy );
143
+ }
144
+
145
+ mlir::Value
146
+ CIRGenFunction::emitNVPTXDevicePrintfCallExpr (const CallExpr *expr) {
147
+ assert (CGM.getTriple ().isNVPTX ());
148
+ CallArgList args;
149
+ emitCallArgs (args,
150
+ expr->getDirectCallee ()->getType ()->getAs <FunctionProtoType>(),
151
+ expr->arguments (), expr->getDirectCallee ());
152
+
153
+ mlir::Location loc = getLoc (expr->getBeginLoc ());
154
+
155
+ // Except the format string, no non-scalar arguments are allowed for
156
+ // device-side printf.
157
+ bool hasNonScalar =
158
+ llvm::any_of (llvm::drop_begin (args), [&](const CallArg &A) {
159
+ return !A.getRValue (*this , loc).isScalar ();
160
+ });
161
+ if (hasNonScalar) {
162
+ CGM.ErrorUnsupported (expr, " non-scalar args to printf" );
163
+ return builder.getConstInt (loc, SInt32Ty, 0 );
164
+ }
165
+
166
+ mlir::Value packedData = packArgsIntoNVPTXFormatBuffer (*this , args, loc);
167
+
168
+ // int vprintf(char *format, void *packedData);
169
+ auto vprintf = CGM.createRuntimeFunction (
170
+ FuncType::get ({cir::PointerType::get (SInt8Ty), VoidPtrTy}, SInt32Ty),
171
+ " vprintf" );
172
+ auto formatString = args[0 ].getRValue (*this , loc).getScalarVal ();
173
+ return builder.createCallOp (loc, vprintf, {formatString, packedData})
174
+ .getResult ();
175
+ }
0 commit comments