Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19085,6 +19085,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
ArrayRef{OpExpr});
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
// defined in SPIRVBuiltins.td. So instead we manually get the matching name
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
Expand Down
18 changes: 18 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s -DTARGET=dx
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s -DTARGET=spv

// Test basic lowering to runtime function call.

// CHECK-LABEL: test_bool
int test_bool(bool expr) {
// CHECK: call {{.*}} @llvm.[[TARGET]].wave.active.countbits
return WaveActiveCountBits(expr);
}

// CHECK: declare i32 @llvm.[[TARGET]].wave.active.countbits(i1) #[[#attr:]]

// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
18 changes: 18 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

int test_too_few_arg() {
return __builtin_hlsl_wave_active_count_bits();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

int test_too_many_arg(bool x) {
return __builtin_hlsl_wave_active_count_bits(x, x);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

struct S { float f; };

int test_bad_conversion(S x) {
return __builtin_hlsl_wave_active_count_bits(x);
// expected-error@-1 {{no viable conversion from 'S' to 'bool'}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ let TargetPrefix = "spv" in {
[IntrNoMem, Commutative] >;
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -880,3 +880,12 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def WaveAllBitCount : DXILOp<135, waveAllOp> {
let Doc = "returns the count of bits set to 1 across the wave";
let LLVMIntrinsic = int_dx_wave_active_countbits;
let arguments = [Int1Ty];
let result = Int32Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
36 changes: 36 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

Expand Down Expand Up @@ -1917,6 +1920,37 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
return Result;
}

bool SPIRVInstructionSelector::selectWaveActiveCountBits(
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
assert(I.getNumOperands() == 3);
assert(I.getOperand(2).isReg());
MachineBasicBlock &BB = *I.getParent();

SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
Register BallotReg = MRI->createVirtualRegister(GR.getRegClass(BallotType));

bool Result =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot))
.addDef(BallotReg)
.addUse(GR.getSPIRVTypeID(BallotType))
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
.addUse(I.getOperand(2).getReg())
.constrainAllUses(TII, TRI, RBI);

Result &=
BuildMI(BB, I, I.getDebugLoc(),
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(BallotReg)
.constrainAllUses(TII, TRI, RBI);

return Result;
}

bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down Expand Up @@ -2739,6 +2773,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
} break;
case Intrinsic::spv_saturate:
return selectSaturate(ResVReg, ResType, I);
case Intrinsic::spv_wave_active_countbits:
return selectWaveActiveCountBits(ResVReg, ResType, I);
case Intrinsic::spv_wave_is_first_lane: {
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
return BuildMI(BB, I, I.getDebugLoc(),
Expand Down
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s

define void @main(i1 %expr) {
entry:
; CHECK: call i32 @dx.op.waveAllOp(i32 135, i1 %expr)
%0 = call i32 @llvm.dx.wave.active.countbits(i1 %expr)
ret void
}

declare i32 @llvm.dx.wave.active.countbits(i1)
19 changes: 19 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4
; CHECK-DAG: %[[#bool:]] = OpTypeBool
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3

; CHECK-LABEL: Begin function test_fun
; CHECK: %[[#bexpr:]] = OpFunctionParameter %[[#bool]]
define i32 @test_fun(i1 %expr) {
entry:
; CHECK: %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] %[[#scope]] %[[#bexpr]]
; CHECK: %[[#ret:]] = OpGroupNonUniformBallotBitCount %[[#uint]] %[[#scope]] Reduce %[[#ballot]]
%0 = call i32 @llvm.spv.wave.active.countbits(i1 %expr)
ret i32 %0
}

declare i32 @llvm.dx.wave.active.countbits(i1)
Loading