Skip to content

Commit d3a7f9d

Browse files
committed
Added support for 2 kernel query builtins
1 parent 37edd2c commit d3a7f9d

File tree

7 files changed

+180
-10
lines changed

7 files changed

+180
-10
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -372,18 +372,15 @@ static MachineInstr *getBlockStructInstr(Register ParamReg,
372372
// We expect the following sequence of instructions:
373373
// %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
374374
// or = G_GLOBAL_VALUE @block_literal_global
375-
// %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
376-
// %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
375+
// %1:_(p4) = G_ADDRSPACE_CAST %0:_(pN)
377376
MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
378377
assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
379378
MI->getOperand(1).isReg());
380-
Register BitcastReg = MI->getOperand(1).getReg();
381-
MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
382-
assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
383-
BitcastMI->getOperand(2).isReg());
384-
Register ValueReg = BitcastMI->getOperand(2).getReg();
385-
MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
386-
return ValueMI;
379+
Register PtrReg = MI->getOperand(1).getReg();
380+
MachineInstr *PtrMI = MRI->getUniqueVRegDef(PtrReg);
381+
assert(PtrMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
382+
isSpvIntrinsic(*PtrMI, Intrinsic::spv_alloca));
383+
return PtrMI;
387384
}
388385

389386
// Return an integer constant corresponding to the given register and
@@ -2509,6 +2506,59 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
25092506
return true;
25102507
}
25112508

2509+
static bool buildNDRangeSubGroup(const SPIRV::IncomingCall *Call,
2510+
unsigned Opcode, MachineIRBuilder &MIRBuilder,
2511+
SPIRVGlobalRegistry *GR) {
2512+
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2513+
const DataLayout &DL = MIRBuilder.getDataLayout();
2514+
2515+
auto MIB = MIRBuilder.buildInstr(Opcode)
2516+
.addDef(Call->ReturnRegister)
2517+
.addUse(GR->getSPIRVTypeID(Call->ReturnType))
2518+
.addUse(Call->Arguments[0]);
2519+
unsigned int BlockFIdx = 1;
2520+
MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
2521+
assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
2522+
// Invoke: Pointer to invoke function.
2523+
Register BlockFReg = BlockMI->getOperand(0).getReg();
2524+
MIB.addUse(BlockFReg);
2525+
MRI->setRegClass(BlockFReg, &SPIRV::pIDRegClass);
2526+
2527+
Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
2528+
// Param: Pointer to block literal.
2529+
MIB.addUse(BlockLiteralReg);
2530+
BlockMI = MRI->getUniqueVRegDef(BlockLiteralReg);
2531+
Register BlockMIReg =
2532+
stripAddrspaceCast(BlockMI->getOperand(1).getReg(), *MRI);
2533+
BlockMI = MRI->getUniqueVRegDef(BlockMIReg);
2534+
2535+
if (BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE) {
2536+
// Size and align are given explicitly here.
2537+
const GlobalValue *GV = BlockMI->getOperand(1).getGlobal();
2538+
2539+
const GlobalVariable *BlockGV = dyn_cast<GlobalVariable>(GV);
2540+
assert(BlockGV->hasInitializer() &&
2541+
"Block literal should have an initializer");
2542+
const Constant *Init = BlockGV->getInitializer();
2543+
const ConstantStruct *CS = dyn_cast<ConstantStruct>(Init);
2544+
// Extract fields
2545+
const ConstantInt *SizeConst = dyn_cast<ConstantInt>(CS->getOperand(0));
2546+
const ConstantInt *AlignConst = dyn_cast<ConstantInt>(CS->getOperand(1));
2547+
uint64_t BlockSize = SizeConst->getZExtValue();
2548+
uint64_t BlockAlign = AlignConst->getZExtValue();
2549+
MIB.addUse(buildConstantIntReg32(BlockSize, MIRBuilder, GR));
2550+
MIB.addUse(buildConstantIntReg32(BlockAlign, MIRBuilder, GR));
2551+
} else {
2552+
Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
2553+
// Fallback to default if not found
2554+
MIB.addUse(
2555+
buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
2556+
MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
2557+
MIRBuilder, GR));
2558+
}
2559+
return true;
2560+
}
2561+
25122562
static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
25132563
MachineIRBuilder &MIRBuilder,
25142564
SPIRVGlobalRegistry *GR) {
@@ -2544,6 +2594,9 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
25442594
return buildNDRange(Call, MIRBuilder, GR);
25452595
case SPIRV::OpEnqueueKernel:
25462596
return buildEnqueueKernel(Call, MIRBuilder, GR);
2597+
case SPIRV::OpGetKernelNDrangeSubGroupCount:
2598+
case SPIRV::OpGetKernelNDrangeMaxSubGroupSize:
2599+
return buildNDRangeSubGroup(Call, Opcode, MIRBuilder, GR);
25472600
default:
25482601
return false;
25492602
}

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,9 @@ defm : DemangledNativeBuiltin<"__spirv_GetDefaultQueue", OpenCL_std, Enqueue, 0,
671671
defm : DemangledNativeBuiltin<"ndrange_1D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
672672
defm : DemangledNativeBuiltin<"ndrange_2D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
673673
defm : DemangledNativeBuiltin<"ndrange_3D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
674+
// Kernel sub-group queries for a given ND-range. Both builtins take exactly
// three arguments: the ND-range, the invoke function and the block literal.
defm : DemangledNativeBuiltin<"__get_kernel_sub_group_count_for_ndrange_impl", OpenCL_std, Enqueue, 3, 3, OpGetKernelNDrangeSubGroupCount>;
defm : DemangledNativeBuiltin<"__get_kernel_max_sub_group_size_for_ndrange_impl", OpenCL_std, Enqueue, 3, 3, OpGetKernelNDrangeMaxSubGroupSize>;
676+
674677

675678
// Spec constant builtin records:
676679
defm : DemangledNativeBuiltin<"__spirv_SpecConstant", OpenCL_std, SpecConstant, 2, 2, OpSpecConstant>;

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,10 @@ def OpGetDefaultQueue: Op<303, (outs ID:$res), (ins TYPE:$type),
759759
"$res = OpGetDefaultQueue $type">;
760760
def OpBuildNDRange: Op<304, (outs ID:$res), (ins TYPE:$type, ID:$GWS, ID:$LWS, ID:$GWO),
761761
"$res = OpBuildNDRange $type $GWS $LWS $GWO">;
762+
// Device-side enqueue kernel queries. Each instruction needs its own opcode
// from the SPIR-V specification; 304 is already taken by OpBuildNDRange.
def OpGetKernelNDrangeSubGroupCount: Op<293, (outs ID:$res), (ins TYPE:$type, ID:$NDR, ID:$Invoke, ID:$Param, ID:$ParamSize, ID:$ParamAlign),
                  "$res = OpGetKernelNDrangeSubGroupCount $type $NDR $Invoke $Param $ParamSize $ParamAlign">;
def OpGetKernelNDrangeMaxSubGroupSize: Op<294, (outs ID:$res), (ins TYPE:$type, ID:$NDR, ID:$Invoke, ID:$Param, ID:$ParamSize, ID:$ParamAlign),
                  "$res = OpGetKernelNDrangeMaxSubGroupSize $type $NDR $Invoke $Param $ParamSize $ParamAlign">;
762766

763767
// TODO: 3.42.23. Pipe Instructions
764768

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1848,6 +1848,11 @@ void addInstrRequirements(const MachineInstr &MI,
18481848
Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
18491849
break;
18501850
}
1851+
case SPIRV::OpGetKernelNDrangeMaxSubGroupSize:
1852+
case SPIRV::OpGetKernelNDrangeSubGroupCount: {
1853+
Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
1854+
break;
1855+
}
18511856

18521857
default:
18531858
break;

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,16 @@ bool isEntryPoint(const Function &F) {
488488
return false;
489489
}
490490

491+
// Walk backwards through a chain of G_ADDRSPACE_CAST instructions, starting at
// the definition of Reg, and return the first source register that is not
// itself defined by such a cast. Reg is returned unchanged when it has no
// definition or its definition is not an address-space cast.
Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI) {
  for (MachineInstr *CastMI = MRI.getVRegDef(Reg);
       CastMI && CastMI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST;
       CastMI = MRI.getVRegDef(Reg)) {
    // Step to the register being cast.
    Reg = CastMI->getOperand(1).getReg();
  }
  return Reg;
}
500+
491501
Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
492502
TypeName.consume_front("atomic_");
493503
if (TypeName.consume_front("void"))

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ bool isSpecialOpaqueType(const Type *Ty);
257257

258258
// Check if the function is an SPIR-V entry point
259259
bool isEntryPoint(const Function &F);
260-
260+
// Look through any chain of G_ADDRSPACE_CAST instructions defining Reg and
// return the underlying source register (Reg itself if it is not a cast).
Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI);
261261
// Parse basic scalar type name, substring TypeName, and return LLVM type.
262262
Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx);
263263

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
3+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
4+
target triple = "spir-unknown-unknown"
5+
6+
%struct.ndrange_t = type { i32 }
7+
%1 = type <{ i32, i32 }>
8+
9+
@__block_literal_global = internal addrspace(1) constant { i32, i32 } { i32 8, i32 4 }, align 4
10+
@__block_literal_global.1 = internal addrspace(1) constant { i32, i32 } { i32 8, i32 4 }, align 4
11+
12+
; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
13+
; CHECK-DAG: %[[#C4:]] = OpConstant %[[#Int32Ty]] 4
14+
; CHECK-DAG: %[[#C8:]] = OpConstant %[[#Int32Ty]] 8
15+
; CHECK-DAG: %[[#NDRangeTy:]] = OpTypeStruct %[[#Int32Ty]]
16+
; CHECK-DAG: %[[#NDRangePtrTy:]] = OpTypePointer Function %[[#NDRangeTy]]
17+
18+
; Function Attrs: convergent noinline nounwind optnone
19+
define spir_kernel void @device_side_enqueue() #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !2 !kernel_arg_type !2 !kernel_arg_base_type !2 !kernel_arg_type_qual !2 {
20+
entry:
21+
22+
; CHECK: %[[#NDRange:]] = OpVariable %[[#NDRangePtrTy]]
23+
24+
%ndrange = alloca %struct.ndrange_t, align 4
25+
26+
; CHECK: %[[#BlockLit1:]] = OpPtrCastToGeneric %[[#]] %[[#]]
27+
; CHECK: %[[#]] = OpGetKernelNDrangeMaxSubGroupSize %[[#Int32Ty]] %[[#NDRange]] %[[#]] %[[#BlockLit1]] %[[#C8]] %[[#C4]]
28+
29+
%0 = call i32 @__get_kernel_max_sub_group_size_for_ndrange_impl(ptr %ndrange, ptr addrspace(4) addrspacecast (ptr @__device_side_enqueue_block_invoke_kernel to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @__block_literal_global to ptr addrspace(4)))
30+
31+
; CHECK: %[[#BlockLit2:]] = OpPtrCastToGeneric %[[#]] %[[#]]
32+
; CHECK: %[[#]] = OpGetKernelNDrangeSubGroupCount %[[#Int32Ty]] %[[#NDRange]] %[[#]] %[[#BlockLit2]] %[[#C8]] %[[#C4]]
33+
34+
%1 = call i32 @__get_kernel_sub_group_count_for_ndrange_impl(ptr %ndrange, ptr addrspace(4) addrspacecast (ptr @__device_side_enqueue_block_invoke_1_kernel to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @__block_literal_global.1 to ptr addrspace(4)))
35+
ret void
36+
}
37+
38+
declare i32 @__get_kernel_preferred_work_group_size_multiple_impl(ptr addrspace(4), ptr addrspace(4))
39+
40+
; Function Attrs: convergent noinline nounwind optnone
41+
define internal spir_func void @__device_side_enqueue_block_invoke(ptr addrspace(4) %.block_descriptor) #1 {
42+
entry:
43+
%.block_descriptor.addr = alloca ptr addrspace(4), align 4
44+
%block.addr = alloca ptr addrspace(4), align 4
45+
store ptr addrspace(4) %.block_descriptor, ptr %.block_descriptor.addr, align 4
46+
store ptr addrspace(4) %.block_descriptor, ptr %block.addr, align 4
47+
ret void
48+
}
49+
50+
; Function Attrs: nounwind
51+
define internal spir_kernel void @__device_side_enqueue_block_invoke_kernel(ptr addrspace(4)) #2 {
52+
entry:
53+
call void @__device_side_enqueue_block_invoke(ptr addrspace(4) %0)
54+
ret void
55+
}
56+
57+
declare i32 @__get_kernel_max_sub_group_size_for_ndrange_impl(ptr, ptr addrspace(4), ptr addrspace(4))
58+
59+
; Function Attrs: convergent noinline nounwind optnone
60+
define internal spir_func void @__device_side_enqueue_block_invoke_1(ptr addrspace(4) %.block_descriptor) #1 {
61+
entry:
62+
%.block_descriptor.addr = alloca ptr addrspace(4), align 4
63+
%block.addr = alloca ptr addrspace(4), align 4
64+
store ptr addrspace(4) %.block_descriptor, ptr %.block_descriptor.addr, align 4
65+
store ptr addrspace(4) %.block_descriptor, ptr %block.addr, align 4
66+
ret void
67+
}
68+
69+
; Function Attrs: nounwind
70+
define internal spir_kernel void @__device_side_enqueue_block_invoke_1_kernel(ptr addrspace(4)) #2 {
71+
entry:
72+
call void @__device_side_enqueue_block_invoke_1(ptr addrspace(4) %0)
73+
ret void
74+
}
75+
76+
declare i32 @__get_kernel_sub_group_count_for_ndrange_impl(ptr, ptr addrspace(4), ptr addrspace(4))
77+
78+
attributes #0 = { convergent noinline nounwind optnone "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
79+
attributes #1 = { convergent noinline nounwind optnone "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
80+
attributes #2 = { nounwind }
81+
attributes #3 = { argmemonly nounwind }
82+
83+
!llvm.module.flags = !{!0}
84+
!opencl.enable.FP_CONTRACT = !{}
85+
!opencl.ocl.version = !{!1}
86+
!opencl.spir.version = !{!1}
87+
!opencl.used.extensions = !{!2}
88+
!opencl.used.optional.core.features = !{!2}
89+
!opencl.compiler.options = !{!2}
90+
!llvm.ident = !{!3}
91+
92+
!0 = !{i32 1, !"wchar_size", i32 4}
93+
!1 = !{i32 2, i32 0}
94+
!2 = !{}
95+
!3 = !{!"clang version 7.0.0"}

0 commit comments

Comments
 (0)