Skip to content

Commit 3573165

Browse files
committed
[DXIL][SPIRV] Lower WaveActiveCountBits intrinsic
- add codegen for llvm builtin to spirv/directx intrinsic in CGBuiltin.cpp
- add lowering of spirv intrinsic to spirv backend in SPIRVInstructionSelector.cpp
- add lowering of directx intrinsic to dxil op in DXIL.td
- add test cases to illustrate passes
- add test case for semantic analysis
1 parent bf30b6c commit 3573165

File tree

10 files changed

+125
-0
lines changed

10 files changed

+125
-0
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19056,6 +19056,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1905619056
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
1905719057
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
1905819058
}
19059+
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
  // Evaluate the single argument, then forward it to the wave-active
  // count-bits intrinsic reported by the HLSL runtime.
  Value *Predicate = EmitScalarExpr(E->getArg(0));
  llvm::Function *CountBitsFn = Intrinsic::getOrInsertDeclaration(
      &CGM.getModule(),
      CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic());
  return EmitRuntimeCall(CountBitsFn, ArrayRef{Predicate});
}
1905919066
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
1906019067
// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
1906119068
// defined in SPIRVBuiltins.td. So instead we manually get the matching name

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class CGHLSLRuntime {
9191
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
9292
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
9393
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
94+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
9495
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
9596
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
9697
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
7+
8+
// Test basic lowering to runtime function call.
9+
10+
// CHECK-LABEL: test_bool
11+
int test_bool(bool expr) {
12+
// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
13+
// CHECK-SPIRV: %[[RET:.*]] = call spir_func i32 @llvm.spv.wave.active.countbits(i1 %{{.*}}) [ "convergencectrl"(token %[[#entry_tok]]) ]
14+
// CHECK-DXIL: %[[RET:.*]] = call i32 @llvm.dx.wave.active.countbits(i1 %{{.*}})
15+
// CHECK: ret i32 %[[RET]]
16+
return WaveActiveCountBits(expr);
17+
}
18+
19+
// CHECK-DXIL: declare i32 @llvm.dx.wave.active.countbits(i1) #[[#attr:]]
20+
// CHECK-SPIRV: declare i32 @llvm.spv.wave.active.countbits(i1) #[[#attr:]]
21+
22+
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
int test_too_few_arg() {
4+
return __builtin_hlsl_wave_active_count_bits();
5+
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
6+
}
7+
8+
int test_too_many_arg(bool x) {
9+
return __builtin_hlsl_wave_active_count_bits(x, x);
10+
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
11+
}
12+
13+
struct S { float f; };
14+
15+
int test_bad_conversion(S x) {
16+
return __builtin_hlsl_wave_active_count_bits(x);
17+
// expected-error@-1 {{no viable conversion from 'S' to 'bool'}}
18+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
8686
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
8787
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
8888
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
89+
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
8990
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
9091
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
9192
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ let TargetPrefix = "spv" in {
8585
[IntrNoMem, Commutative] >;
8686
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
8787
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
88+
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
8889
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8990
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
9091
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,3 +873,12 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
873873
let stages = [Stages<DXIL1_0, [all_stages]>];
874874
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
875875
}
876+
877+
def WaveAllBitCount : DXILOp<135, waveAllOp> {
878+
let Doc = "returns the count of bits set to 1 across the wave";
879+
let LLVMIntrinsic = int_dx_wave_active_countbits;
880+
let arguments = [Int1Ty];
881+
let result = Int32Ty;
882+
let stages = [Stages<DXIL1_0, [all_stages]>];
883+
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
884+
}

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
256256
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
257257
MachineInstr &I) const;
258258

259+
bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
260+
MachineInstr &I) const;
261+
259262
bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
260263
MachineInstr &I) const;
261264

@@ -1917,6 +1920,38 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
19171920
return Result;
19181921
}
19191922

1923+
/// Select the `spv_wave_active_countbits` intrinsic.
///
/// Two instructions are inserted in front of \p I, both at Subgroup scope:
///   1. OpGroupNonUniformBallot, producing a 4 x 32-bit integer vector from
///      the intrinsic's argument (operand 2 of \p I).
///   2. OpGroupNonUniformBallotBitCount with group operation 0 (Reduce),
///      which writes its result to \p ResVReg.
///
/// \param ResVReg  Register that receives the bit count.
/// \param ResType  SPIR-V type of the result.
/// \param I        The intrinsic instruction being selected.
/// \returns true only if both emitted instructions had their register
///          operands constrained successfully.
bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg,
                                                  const SPIRVType *ResType,
                                                  MachineInstr &I) const {
  assert(I.getNumOperands() == 3);
  assert(I.getOperand(2).isReg());
  MachineBasicBlock &BB = *I.getParent();

  Register BallotReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
  SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
  SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);

  // Constrain the ballot instruction as well: without this the status was
  // merely "the builder is non-null", i.e. always true.
  bool Result =
      BuildMI(BB, I, I.getDebugLoc(),
              TII.get(SPIRV::OpGroupNonUniformBallot))
          .addDef(BallotReg)
          .addUse(GR.getSPIRVTypeID(BallotType))
          .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
          .addUse(I.getOperand(2).getReg())
          .constrainAllUses(TII, TRI, RBI);

  // AND the statuses together so a failure on either instruction is reported;
  // OR-ing them would hide a failure of the second instruction.
  Result &=
      BuildMI(BB, I, I.getDebugLoc(),
              TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
          .addDef(ResVReg)
          .addUse(GR.getSPIRVTypeID(ResType))
          .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
          .addImm(0) // Group operation: Reduce.
          .addUse(BallotReg)
          .constrainAllUses(TII, TRI, RBI);

  return Result;
}
1954+
19201955
bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
19211956
const SPIRVType *ResType,
19221957
MachineInstr &I) const {
@@ -2739,6 +2774,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
27392774
} break;
27402775
case Intrinsic::spv_saturate:
27412776
return selectSaturate(ResVReg, ResType, I);
2777+
case Intrinsic::spv_wave_active_countbits:
2778+
return selectWaveActiveCountBits(ResVReg, ResType, I);
27422779
case Intrinsic::spv_wave_is_first_lane: {
27432780
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
27442781
return BuildMI(BB, I, I.getDebugLoc(),
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
define void @main(i1 %expr) {
4+
entry:
5+
; CHECK: call i32 @dx.op.waveAllOp(i32 135, i1 %expr)
6+
%0 = call i32 @llvm.dx.wave.active.countbits(i1 %expr)
7+
ret void
8+
}
9+
10+
declare i32 @llvm.dx.wave.active.countbits(i1)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
5+
; CHECK-DAG: %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4
6+
; CHECK-DAG: %[[#bool:]] = OpTypeBool
7+
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
8+
9+
; CHECK-LABEL: Begin function test_fun
10+
; CHECK: %[[#bexpr:]] = OpFunctionParameter %[[#bool]]
11+
define i32 @test_fun(i1 %expr) {
12+
entry:
13+
; CHECK: %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] %[[#scope]] %[[#bexpr]]
14+
; CHECK: %[[#ret:]] = OpGroupNonUniformBallotBitCount %[[#uint]] %[[#scope]] Reduce %[[#ballot]]
15+
%0 = call i32 @llvm.spv.wave.active.countbits(i1 %expr)
16+
ret i32 %0
17+
}
18+
19+
; Declare the SPIR-V intrinsic that test_fun actually calls.
declare i32 @llvm.spv.wave.active.countbits(i1)

0 commit comments

Comments
 (0)