Skip to content

Commit cab416d

Browse files
committed
riscv: support byval tail call arguments
1 parent 6c75b24 commit cab416d

File tree

4 files changed

+242
-44
lines changed

4 files changed

+242
-44
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23549,6 +23549,8 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
2354923549
continue;
2355023550
}
2355123551
InVals.push_back(ArgValue);
23552+
if (Ins[InsIdx].Flags.isByVal())
23553+
RVFI->addIncomingByValArgs(ArgValue);
2355223554
}
2355323555

2355423556
if (any_of(ArgLocs,
@@ -23561,7 +23563,6 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
2356123563
const TargetRegisterClass *RC = &RISCV::GPRRegClass;
2356223564
MachineFrameInfo &MFI = MF.getFrameInfo();
2356323565
MachineRegisterInfo &RegInfo = MF.getRegInfo();
23564-
RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
2356523566

2356623567
// Size of the vararg save area. For now, the varargs save area is either
2356723568
// zero or large enough to hold a0-a7.
@@ -23647,25 +23648,23 @@ bool RISCVTargetLowering::isEligibleForTailCallOptimization(
2364723648
if (CCInfo.getStackSize() > RVFI->getArgumentStackSize())
2364823649
return false;
2364923650

23650-
// Do not tail call opt if any parameters need to be passed indirectly.
23651-
// Since long doubles (fp128) and i128 are larger than 2*XLEN, they are
23652-
// passed indirectly. So the address of the value will be passed in a
23653-
// register, or if not available, then the address is put on the stack. In
23654-
// order to pass indirectly, space on the stack often needs to be allocated
23655-
// in order to store the value. In this case the CCInfo.getNextStackOffset()
23656-
// != 0 check is not enough and we need to check if any CCValAssign ArgsLocs
23657-
// are passed CCValAssign::Indirect.
23658-
for (auto &VA : ArgLocs)
23659-
if (VA.getLocInfo() == CCValAssign::Indirect)
23660-
return false;
23661-
2366223651
// Do not tail call opt if either caller or callee uses struct return
2366323652
// semantics.
2366423653
auto IsCallerStructRet = Caller.hasStructRetAttr();
2366523654
auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet();
2366623655
if (IsCallerStructRet != IsCalleeStructRet)
2366723656
return false;
2366823657

23658+
// Do not tail call opt if caller's and callee's byval arguments do not match.
23659+
for (unsigned i = 0, j = 0; i < Outs.size(); i++) {
23660+
if (!Outs[i].Flags.isByVal())
23661+
continue;
23662+
if (j++ >= RVFI->getIncomingByValArgsSize())
23663+
return false;
23664+
if (RVFI->getIncomingByValArgs(i).getValueType() != Outs[i].ArgVT)
23665+
return false;
23666+
}
23667+
2366923668
// The callee has to preserve all registers the caller needs to preserve.
2367023669
const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
2367123670
const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
@@ -23709,6 +23708,7 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
2370923708
const CallBase *CB = CLI.CB;
2371023709

2371123710
MachineFunction &MF = DAG.getMachineFunction();
23711+
RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
2371223712
MachineFunction::CallSiteInfo CSInfo;
2371323713

2371423714
// Set type id for call site info.
@@ -23743,7 +23743,7 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
2374323743

2374423744
// Create local copies for byval args
2374523745
SmallVector<SDValue, 8> ByValArgs;
23746-
for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
23746+
for (unsigned i = 0, j = 0, e = Outs.size(); i != e; ++i) {
2374723747
ISD::ArgFlagsTy Flags = Outs[i].Flags;
2374823748
if (!Flags.isByVal())
2374923749
continue;
@@ -23752,16 +23752,27 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
2375223752
unsigned Size = Flags.getByValSize();
2375323753
Align Alignment = Flags.getNonZeroByValAlign();
2375423754

23755-
int FI =
23756-
MF.getFrameInfo().CreateStackObject(Size, Alignment, /*isSS=*/false);
23757-
SDValue FIPtr = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
2375823755
SDValue SizeNode = DAG.getConstant(Size, DL, XLenVT);
23756+
SDValue Dst;
2375923757

23760-
Chain = DAG.getMemcpy(Chain, DL, FIPtr, Arg, SizeNode, Alignment,
23761-
/*IsVolatile=*/false,
23762-
/*AlwaysInline=*/false, /*CI*/ nullptr, IsTailCall,
23763-
MachinePointerInfo(), MachinePointerInfo());
23764-
ByValArgs.push_back(FIPtr);
23758+
if (IsTailCall) {
23759+
SDValue CallerArg = RVFI->getIncomingByValArgs(j++);
23760+
if (isa<GlobalAddressSDNode>(Arg) || isa<ExternalSymbolSDNode>(Arg) ||
23761+
isa<FrameIndexSDNode>(Arg))
23762+
Dst = CallerArg;
23763+
} else {
23764+
int FI =
23765+
MF.getFrameInfo().CreateStackObject(Size, Alignment, /*isSS=*/false);
23766+
Dst = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
23767+
}
23768+
if (Dst) {
23769+
Chain =
23770+
DAG.getMemcpy(Chain, DL, Dst, Arg, SizeNode, Alignment,
23771+
/*IsVolatile=*/false,
23772+
/*AlwaysInline=*/false, /*CI=*/nullptr, std::nullopt,
23773+
MachinePointerInfo(), MachinePointerInfo());
23774+
ByValArgs.push_back(Dst);
23775+
}
2376523776
}
2376623777

2376723778
if (!IsTailCall)
@@ -23864,8 +23875,12 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
2386423875
}
2386523876

2386623877
// Use local copy if it is a byval arg.
23867-
if (Flags.isByVal())
23868-
ArgValue = ByValArgs[j++];
23878+
if (Flags.isByVal()) {
23879+
if (!IsTailCall || (isa<GlobalAddressSDNode>(ArgValue) ||
23880+
isa<ExternalSymbolSDNode>(ArgValue) ||
23881+
isa<FrameIndexSDNode>(ArgValue)))
23882+
ArgValue = ByValArgs[j++];
23883+
}
2386923884

2387023885
if (VA.isRegLoc()) {
2387123886
// Queue up the argument copies and emit them at the end.

llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {
7070
/// being passed on the stack
7171
unsigned ArgumentStackSize = 0;
7272

73+
/// Incoming ByVal arguments
74+
SmallVector<SDValue, 8> IncomingByValArgs;
75+
7376
/// Is there any vector argument or return?
7477
bool IsVectorCall = false;
7578

@@ -150,6 +153,10 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {
150153
unsigned getArgumentStackSize() const { return ArgumentStackSize; }
151154
void setArgumentStackSize(unsigned size) { ArgumentStackSize = size; }
152155

156+
void addIncomingByValArgs(SDValue Val) { IncomingByValArgs.push_back(Val); }
157+
SDValue &getIncomingByValArgs(int Idx) { return IncomingByValArgs[Idx]; }
158+
unsigned getIncomingByValArgsSize() { return IncomingByValArgs.size(); }
159+
153160
enum class PushPopKind { None = 0, StdExtZcmp, VendorXqccmp };
154161

155162
PushPopKind getPushPopKind(const MachineFunction &MF) const;

llvm/test/CodeGen/RISCV/musttail.ll

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,179 @@ entry:
393393
musttail call void @sret_callee(ptr sret({ double, double }) align 8 %result)
394394
ret void
395395
}
396+
397+
%twenty_bytes = type { [5 x i32] }
398+
declare void @large_callee(%twenty_bytes* byval(%twenty_bytes) align 4)
399+
400+
; Functions with byval parameters can be tail-called, because the value is
401+
; actually passed in registers in the same way for the caller and callee.
402+
define void @large_caller(%twenty_bytes* byval(%twenty_bytes) align 4 %a) {
403+
; RV32-LABEL: large_caller:
404+
; RV32: # %bb.0: # %entry
405+
; RV32-NEXT: tail large_callee
406+
;
407+
; RV64-LABEL: large_caller:
408+
; RV64: # %bb.0: # %entry
409+
; RV64-NEXT: tail large_callee
410+
entry:
411+
musttail call void @large_callee(%twenty_bytes* byval(%twenty_bytes) align 4 %a)
412+
ret void
413+
}
414+
415+
; As above, but with some inline asm to test that the arguments in r4 is
416+
; re-loaded before the call.
417+
define void @large_caller_check_regs(%twenty_bytes* byval(%twenty_bytes) align 4 %a) nounwind {
418+
; RV32-LABEL: large_caller_check_regs:
419+
; RV32: # %bb.0: # %entry
420+
; RV32-NEXT: #APP
421+
; RV32-NEXT: #NO_APP
422+
; RV32-NEXT: tail large_callee
423+
;
424+
; RV64-LABEL: large_caller_check_regs:
425+
; RV64: # %bb.0: # %entry
426+
; RV64-NEXT: #APP
427+
; RV64-NEXT: #NO_APP
428+
; RV64-NEXT: tail large_callee
429+
entry:
430+
tail call void asm sideeffect "", "~{r4}"()
431+
musttail call void @large_callee(%twenty_bytes* byval(%twenty_bytes) align 4 %a)
432+
ret void
433+
}
434+
435+
; The IR for this one looks dodgy, because it has an alloca passed to a
436+
; musttail function, but it is passed as a byval argument, so will be copied
437+
; into the stack space allocated by @large_caller_new_value's caller, so is
438+
; valid.
439+
define void @large_caller_new_value(%twenty_bytes* byval(%twenty_bytes) align 4 %a) nounwind {
440+
; RV32-LABEL: large_caller_new_value:
441+
; RV32: # %bb.0: # %entry
442+
; RV32-NEXT: addi sp, sp, -32
443+
; RV32-NEXT: li a1, 1
444+
; RV32-NEXT: li a2, 2
445+
; RV32-NEXT: li a3, 3
446+
; RV32-NEXT: li a4, 4
447+
; RV32-NEXT: sw zero, 12(sp)
448+
; RV32-NEXT: sw a1, 16(sp)
449+
; RV32-NEXT: sw a2, 20(sp)
450+
; RV32-NEXT: sw a3, 24(sp)
451+
; RV32-NEXT: sw a4, 28(sp)
452+
; RV32-NEXT: sw a4, 16(a0)
453+
; RV32-NEXT: sw zero, 0(a0)
454+
; RV32-NEXT: sw a1, 4(a0)
455+
; RV32-NEXT: sw a2, 8(a0)
456+
; RV32-NEXT: sw a3, 12(a0)
457+
; RV32-NEXT: addi sp, sp, 32
458+
; RV32-NEXT: tail large_callee
459+
;
460+
; RV64-LABEL: large_caller_new_value:
461+
; RV64: # %bb.0: # %entry
462+
; RV64-NEXT: addi sp, sp, -32
463+
; RV64-NEXT: li a1, 1
464+
; RV64-NEXT: li a2, 2
465+
; RV64-NEXT: li a3, 3
466+
; RV64-NEXT: li a4, 4
467+
; RV64-NEXT: sw zero, 12(sp)
468+
; RV64-NEXT: sw a1, 16(sp)
469+
; RV64-NEXT: sw a2, 20(sp)
470+
; RV64-NEXT: sw a3, 24(sp)
471+
; RV64-NEXT: sw a4, 28(sp)
472+
; RV64-NEXT: sw a4, 16(a0)
473+
; RV64-NEXT: sw zero, 0(a0)
474+
; RV64-NEXT: sw a1, 4(a0)
475+
; RV64-NEXT: sw a2, 8(a0)
476+
; RV64-NEXT: sw a3, 12(a0)
477+
; RV64-NEXT: addi sp, sp, 32
478+
; RV64-NEXT: tail large_callee
479+
entry:
480+
%y = alloca %twenty_bytes, align 4
481+
store i32 0, ptr %y, align 4
482+
%0 = getelementptr inbounds i8, ptr %y, i32 4
483+
store i32 1, ptr %0, align 4
484+
%1 = getelementptr inbounds i8, ptr %y, i32 8
485+
store i32 2, ptr %1, align 4
486+
%2 = getelementptr inbounds i8, ptr %y, i32 12
487+
store i32 3, ptr %2, align 4
488+
%3 = getelementptr inbounds i8, ptr %y, i32 16
489+
store i32 4, ptr %3, align 4
490+
musttail call void @large_callee(%twenty_bytes* byval(%twenty_bytes) align 4 %y)
491+
ret void
492+
}
493+
494+
declare void @two_byvals_callee(%twenty_bytes* byval(%twenty_bytes) align 4, %twenty_bytes* byval(%twenty_bytes) align 4)
495+
define void @swap_byvals(%twenty_bytes* byval(%twenty_bytes) align 4 %a, %twenty_bytes* byval(%twenty_bytes) align 4 %b) {
496+
; RV32-LABEL: swap_byvals:
497+
; RV32: # %bb.0: # %entry
498+
; RV32-NEXT: mv a2, a0
499+
; RV32-NEXT: mv a0, a1
500+
; RV32-NEXT: mv a1, a2
501+
; RV32-NEXT: tail two_byvals_callee
502+
;
503+
; RV64-LABEL: swap_byvals:
504+
; RV64: # %bb.0: # %entry
505+
; RV64-NEXT: mv a2, a0
506+
; RV64-NEXT: mv a0, a1
507+
; RV64-NEXT: mv a1, a2
508+
; RV64-NEXT: tail two_byvals_callee
509+
entry:
510+
musttail call void @two_byvals_callee(%twenty_bytes* byval(%twenty_bytes) align 4 %b, %twenty_bytes* byval(%twenty_bytes) align 4 %a)
511+
ret void
512+
}
513+
514+
; A forwarded byval arg, but in a different argument register, so it needs to
515+
; be moved between registers first. This can't be musttail because of the
516+
; different signatures, but is still tail-called as an optimisation.
517+
declare void @shift_byval_callee(%twenty_bytes* byval(%twenty_bytes) align 4)
518+
define void @shift_byval(i32 %a, %twenty_bytes* byval(%twenty_bytes) align 4 %b) {
519+
; RV32-LABEL: shift_byval:
520+
; RV32: # %bb.0: # %entry
521+
; RV32-NEXT: mv a0, a1
522+
; RV32-NEXT: tail shift_byval_callee
523+
;
524+
; RV64-LABEL: shift_byval:
525+
; RV64: # %bb.0: # %entry
526+
; RV64-NEXT: mv a0, a1
527+
; RV64-NEXT: tail shift_byval_callee
528+
entry:
529+
tail call void @shift_byval_callee(%twenty_bytes* byval(%twenty_bytes) align 4 %b)
530+
ret void
531+
}
532+
533+
; A global object passed to a byval argument, so it must be copied, but doesn't
534+
; need a stack temporary.
535+
@large_global = external global %twenty_bytes
536+
define void @large_caller_from_global(%twenty_bytes* byval(%twenty_bytes) align 4 %a) {
537+
; RV32-LABEL: large_caller_from_global:
538+
; RV32: # %bb.0: # %entry
539+
; RV32-NEXT: lui a1, %hi(large_global)
540+
; RV32-NEXT: addi a1, a1, %lo(large_global)
541+
; RV32-NEXT: lw a2, 16(a1)
542+
; RV32-NEXT: sw a2, 16(a0)
543+
; RV32-NEXT: lw a2, 12(a1)
544+
; RV32-NEXT: sw a2, 12(a0)
545+
; RV32-NEXT: lw a2, 8(a1)
546+
; RV32-NEXT: sw a2, 8(a0)
547+
; RV32-NEXT: lw a2, 4(a1)
548+
; RV32-NEXT: sw a2, 4(a0)
549+
; RV32-NEXT: lw a1, 0(a1)
550+
; RV32-NEXT: sw a1, 0(a0)
551+
; RV32-NEXT: tail large_callee
552+
;
553+
; RV64-LABEL: large_caller_from_global:
554+
; RV64: # %bb.0: # %entry
555+
; RV64-NEXT: lui a1, %hi(large_global)
556+
; RV64-NEXT: addi a1, a1, %lo(large_global)
557+
; RV64-NEXT: lw a2, 16(a1)
558+
; RV64-NEXT: sw a2, 16(a0)
559+
; RV64-NEXT: lw a2, 12(a1)
560+
; RV64-NEXT: sw a2, 12(a0)
561+
; RV64-NEXT: lw a2, 8(a1)
562+
; RV64-NEXT: sw a2, 8(a0)
563+
; RV64-NEXT: lw a2, 4(a1)
564+
; RV64-NEXT: sw a2, 4(a0)
565+
; RV64-NEXT: lw a1, 0(a1)
566+
; RV64-NEXT: sw a1, 0(a0)
567+
; RV64-NEXT: tail large_callee
568+
entry:
569+
musttail call void @large_callee(%twenty_bytes* byval(%twenty_bytes) align 4 @large_global)
570+
ret void
571+
}

0 commit comments

Comments
 (0)