// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via ld.param
- // instruction, directly or via a pointer. Pointers to kernel
- // arguments can't be converted to generic address space.
+ // instruction, directly or via a pointer.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// ...
// }
//
- // 2. Convert pointers in a byval kernel parameter to pointers in the global
- //    address space. As #2, it allows NVPTX to emit more ld/st.global. E.g.,
+ // 2. Convert byval kernel parameters to pointers in the param address space
+ //    (so that NVPTX emits ld/st.param). Convert pointers *within* a byval
+ //    kernel parameter to pointers in the global address space. This allows
+ //    NVPTX to emit ld/st.global.
//
//    struct S {
//      int *x;
//
// "b" points to the global address space. In the IR level,
//
- //    define void @foo({i32*, i32*}* byval %input) {
- //      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
- //      %b = load i32*, i32** %b_ptr
+ //    define void @foo(ptr byval %input) {
+ //      %b_ptr = getelementptr {ptr, ptr}, ptr %input, i64 0, i32 1
+ //      %b = load ptr, ptr %b_ptr
//      ; use %b
//    }
//
// becomes
//
//    define void @foo({i32*, i32*}* byval %input) {
- //      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
- //      %b = load i32*, i32** %b_ptr
- //      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
- //      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
+ //      %b_param = addrspacecast ptr %input to ptr addrspace(101)
+ //      %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
+ //      %b = load ptr, ptr addrspace(101) %b_ptr
+ //      %b_global = addrspacecast ptr %b to ptr addrspace(1)
//      ; use %b_generic
//    }
//
+ // Create a local copy of kernel byval parameters that are used in a way that
+ // *might* mutate the parameter, by storing the argument in an alloca.
+ // Mutating a "grid_constant" parameter is undefined behaviour, so such
+ // parameters don't require local copies.
+ //
+ //    define void @foo(ptr byval(%struct.s) align 4 %input) {
+ //      store i32 42, ptr %input
+ //      ret void
+ //    }
+ //
+ // becomes
+ //
+ //    define void @foo(ptr byval(%struct.s) align 4 %input) #1 {
+ //      %input1 = alloca %struct.s, align 4
+ //      %input2 = addrspacecast ptr %input to ptr addrspace(101)
+ //      %input3 = load %struct.s, ptr addrspace(101) %input2, align 4
+ //      store %struct.s %input3, ptr %input1, align 4
+ //      store i32 42, ptr %input1, align 4
+ //      ret void
+ //    }
+ //
+ // If %input were passed to a device function or written to memory, we
+ // conservatively assume that it gets mutated and create a local copy.
+ //
+ // Convert param-space pointers to grid_constant byval kernel parameters that
+ // are passed into calls (device functions, intrinsics, inline asm) or
+ // otherwise "escape" (into stores/ptrtoints) into the generic address space,
+ // using the `nvvm.ptr.param.to.gen` intrinsic, so that NVPTX emits cvta.param
+ // (available on sm_70+).
+ //
+ //    define void @foo(ptr byval(%struct.s) %input) {
+ //      ; %input is a grid_constant
+ //      %call = call i32 @escape(ptr %input)
+ //      ret void
+ //    }
+ //
+ // becomes
+ //
+ //    define void @foo(ptr byval(%struct.s) %input) {
+ //      %input1 = addrspacecast ptr %input to ptr addrspace(101)
+ //      ; the following intrinsic converts the pointer to the generic address
+ //      ; space. We don't use an addrspacecast, to keep the resulting
+ //      ; generic -> param -> generic round trip from being folded away.
+ //      %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
+ //      %call = call i32 @escape(ptr %input1.gen)
+ //      ret void
+ //    }
+ //
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
// cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//
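
For orientation, here is a minimal CUDA-level sketch of the two byval situations the comment above describes. The struct, kernel, and function names are illustrative, not taken from this file, and __grid_constant__ assumes sm_70+ with a recent CUDA toolkit:

struct S {
  int *x;
  int *y;
};

// Ordinary byval aggregate parameter: reads of s.x/s.y become ld.param, and
// the pointers they yield are assumed to point into global memory, so the
// dereferences can be emitted as ld.global/st.global.
__global__ void foo(S s) { *s.x = *s.y; }

__device__ int consume(const S *p) { return *p->x + *p->y; }

// __grid_constant__ parameter: read-only, so no local copy is needed; taking
// its address and passing it to a device function makes the pointer escape
// into the generic address space via cvta.param.
__global__ void bar(const __grid_constant__ S s, int *out) { *out = consume(&s); }
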
@@ -166,19 +213,22 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
// ones in parameter AS, so we can access them using ld.param.
// =============================================================================

- // Replaces the \p OldUser instruction with the same in parameter AS.
- // Only Load and GEP are supported.
- static void convertToParamAS(Value *OldUser, Value *Param) {
-   Instruction *I = dyn_cast<Instruction>(OldUser);
-   assert(I && "OldUser must be an instruction");
+ // For Loads, replaces the \p OldUse of the pointer with a Use of the same
+ // pointer in parameter AS.
+ // For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
+ // generic using cvta.param.
+ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
+   Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
+   assert(I && "OldUse must be in an instruction");
  struct IP {
+     Use *OldUse;
    Instruction *OldInstruction;
    Value *NewParam;
  };
-   SmallVector<IP> ItemsToConvert = {{I, Param}};
+   SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
  SmallVector<Instruction *> InstructionsToDelete;

-   auto CloneInstInParamAS = [](const IP &I) -> Value * {
+   auto CloneInstInParamAS = [GridConstant](const IP &I) -> Value * {
    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
      LI->setOperand(0, I.NewParam);
      return LI;
@@ -202,6 +252,43 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
      // Just pass through the argument, the old ASC is no longer needed.
      return I.NewParam;
    }
+
+     if (GridConstant) {
+       auto GetParamAddrCastToGeneric =
+           [](Value *Addr, Instruction *OriginalUser) -> Value * {
+         PointerType *ReturnTy =
+             PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
+         Function *CvtToGen = Intrinsic::getDeclaration(
+             OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
+             {ReturnTy, PointerType::get(OriginalUser->getContext(),
+                                         ADDRESS_SPACE_PARAM)});
+
+         // Cast param address to generic address space
+         Value *CvtToGenCall =
+             CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
+                              OriginalUser->getIterator());
+         return CvtToGenCall;
+       };
+
+       if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
+         I.OldUse->set(GetParamAddrCastToGeneric(I.NewParam, CI));
+         return CI;
+       }
+       if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
+         // byval address is being stored, cast it to generic
+         if (SI->getValueOperand() == I.OldUse->get())
+           SI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, SI));
+         return SI;
+       }
+       if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
+         if (PI->getPointerOperand() == I.OldUse->get())
+           PI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, PI));
+         return PI;
+       }
+       llvm_unreachable(
+           "Instruction unsupported even for grid_constant argument");
+     }
+
    llvm_unreachable("Unsupported instruction");
  };

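As a rough CUDA-level illustration of the three "escape" shapes this branch rewrites through nvvm.ptr.param.to.gen (the pointer passed to a call, stored to memory, or converted to an integer); the kernel and names below are hypothetical:

struct Params {
  int n;
  float scale;
};

__device__ const Params *g_saved; // device-side slot the pointer escapes into

__device__ float use_params(const Params *p) { return p->scale * p->n; }

__global__ void escapes(const __grid_constant__ Params cfg, float *out,
                        unsigned long long *bits) {
  *out = use_params(&cfg);          // parameter pointer passed to a call
  g_saved = &cfg;                   // parameter pointer stored to memory
  *bits = (unsigned long long)&cfg; // parameter pointer converted to an integer
}
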
@@ -213,8 +300,8 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
      // We've created a new instruction. Queue users of the old instruction to
      // be converted and the instruction itself to be deleted. We can't delete
      // the old instruction yet, because it's still in use by a load somewhere.
-       for (Value *V : I.OldInstruction->users())
-         ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+       for (Use &U : I.OldInstruction->uses())
+         ItemsToConvert.push_back({&U, cast<Instruction>(U.getUser()), NewInst});

      InstructionsToDelete.push_back(I.OldInstruction);
    }
@@ -272,6 +359,7 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
  SmallVector<Load> Loads;
  std::queue<LoadContext> Worklist;
  Worklist.push({ArgInParamAS, 0});
+   bool IsGridConstant = isParamGridConstant(*Arg);

  while (!Worklist.empty()) {
    LoadContext Ctx = Worklist.front();
@@ -303,8 +391,14 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
      continue;
    }

+     // Calls, stores, and ptrtoints are supported only for grid_constant arguments.
+     if (IsGridConstant &&
+         (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
+          isa<PtrToIntInst>(CurUser)))
+       continue;
+
    llvm_unreachable("All users must be one of: load, "
-                      "bitcast, getelementptr.");
+                      "bitcast, getelementptr, call, store, ptrtoint");
  }
}

@@ -317,49 +411,59 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,

void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
                                      Argument *Arg) {
+   bool IsGridConstant = isParamGridConstant(*Arg);
  Function *Func = Arg->getParent();
  BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
  Type *StructType = Arg->getParamByValType();
  assert(StructType && "Missing byval type");

-   auto IsALoadChain = [&](Value *Start) {
+   auto AreSupportedUsers = [&](Value *Start) {
    SmallVector<Value *, 16> ValuesToCheck = {Start};
-     auto IsALoadChainInstr = [](Value *V) -> bool {
+     auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
        return true;
      // ASC to param space are OK, too -- we'll just strip them.
      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
          return true;
      }
+       // Simple calls and stores are supported for grid_constants;
+       // writes through these pointers are undefined behaviour.
+       if (IsGridConstant &&
+           (isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
+         return true;
      return false;
    };

    while (!ValuesToCheck.empty()) {
      Value *V = ValuesToCheck.pop_back_val();
-       if (!IsALoadChainInstr(V)) {
+       if (!IsSupportedUse(V)) {
        LLVM_DEBUG(dbgs() << "Need a "
                          << (isParamGridConstant(*Arg) ? "cast " : "copy ")
                          << "of " << *Arg << " because of " << *V << "\n");
        (void)Arg;
        return false;
      }
-       if (!isa<LoadInst>(V))
+       if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
+           !isa<PtrToIntInst>(V))
        llvm::append_range(ValuesToCheck, V->users());
    }
    return true;
  };

-   if (llvm::all_of(Arg->users(), IsALoadChain)) {
+   if (llvm::all_of(Arg->users(), AreSupportedUsers)) {
    // Convert all loads and intermediate operations to use parameter AS and
    // skip creation of a local copy of the argument.
-     SmallVector<User *, 16> UsersToUpdate(Arg->users());
+     SmallVector<Use *, 16> UsesToUpdate;
+     for (Use &U : Arg->uses())
+       UsesToUpdate.push_back(&U);
+
    Value *ArgInParamAS = new AddrSpaceCastInst(
        Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
        FirstInst);
-     for (Value *V : UsersToUpdate)
-       convertToParamAS(V, ArgInParamAS);
-     LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
+     for (Use *U : UsesToUpdate)
+       convertToParamAS(U, ArgInParamAS, IsGridConstant);
+     LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");

    const auto *TLI =
        cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
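
In CUDA terms, the check above roughly separates kernels like the following two (a sketch with illustrative names): the first only reads its byval argument, so its accesses can be rewritten to the param address space with no copy; the second may mutate the argument, so the fallback path below materializes a local copy first.

struct Box {
  int v[4];
};

// Only loads/GEPs of the argument: handled by convertToParamAS, no alloca.
__global__ void sum(const Box b, int *out) { *out = b.v[0] + b.v[3]; }

// The argument is written through, so the pass conservatively copies it into
// a local alloca before rewriting the accesses.
__global__ void bump(Box b, int *out) {
  b.v[0] += 1;
  *out = b.v[0];
}
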
@@ -376,16 +480,11 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
  // temporary copy. When a pointer might have escaped, conservatively replace
  // all of its uses (which might include a device function call) with a cast
  // to the generic address space.
-   // TODO: only cast byval grid constant parameters at use points that need
-   // generic address (e.g., merging parameter pointers with other address
-   // space, or escaping to call-sites, inline-asm, memory), and use the
-   // parameter address space for normal loads.
  IRBuilder<> IRB(&Func->getEntryBlock().front());

  // Cast argument to param address space
-   auto *CastToParam =
-       cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
-           Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
+   auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
+       Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));

  // Cast param address to generic address space. We do not use an
  // addrspacecast to generic here, because LLVM considers `Arg` to be in the