Skip to content

Commit 04af578

Browse files
committed
Added WaveActiveBitOr
1 parent 893b1d4 commit 04af578

File tree

16 files changed

+251
-1
lines changed

16 files changed

+251
-1
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4999,6 +4999,12 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
49994999
let Prototype = "bool(bool)";
50005000
}
50015001

5002+
def HLSLWaveActiveBitOr : LangBuiltin<"HLSL_LANG"> {
5003+
let Spellings = ["__builtin_hlsl_wave_active_bit_or"];
5004+
let Attributes = [NoThrow, Const];
5005+
let Prototype = "void (...)";
5006+
}
5007+
50025008
def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
50035009
let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
50045010
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGHLSLBuiltins.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
690690
return EmitRuntimeCall(
691691
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
692692
}
693+
case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
694+
Value *Op = EmitScalarExpr(E->getArg(0));
695+
assert(Op->getType()->hasUnsignedIntegerRepresentation() &&
696+
"Intrinsic WaveActiveBitOr operand must have a unsigned integer representation");
697+
698+
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBitOrIntrinsic();
699+
return EmitRuntimeCall(
700+
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
701+
}
693702
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
694703
Value *OpExpr = EmitScalarExpr(E->getArg(0));
695704
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class CGHLSLRuntime {
113113
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
114114
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
115115
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
116+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBitOr, wave_reduce_or)
116117
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
117118
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
118119
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)

clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,6 +2498,36 @@ __attribute__((convergent)) double3 WaveReadLaneAt(double3, uint32_t);
24982498
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
24992499
__attribute__((convergent)) double4 WaveReadLaneAt(double4, uint32_t);
25002500

2501+
//===----------------------------------------------------------------------===//
2502+
// WaveActiveBitOr builtins
2503+
//===----------------------------------------------------------------------===//
2504+
2505+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2506+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
2507+
__attribute__((convergent)) uint WaveActiveBitOr(uint);
2508+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2509+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
2510+
__attribute__((convergent)) uint2 WaveActiveBitOr(uint2);
2511+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2512+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
2513+
__attribute__((convergent)) uint3 WaveActiveBitOr(uint3);
2514+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2515+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
2516+
__attribute__((convergent)) uint4 WaveActiveBitOr(uint4);
2517+
2518+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2519+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
2520+
__attribute__((convergent)) uint64_t WaveActiveBitOr(uint64_t);
2521+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2522+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
2523+
__attribute__((convergent)) uint64_t2 WaveActiveBitOr(uint64_t2);
2524+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2525+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
2526+
__attribute__((convergent)) uint64_t3 WaveActiveBitOr(uint64_t3);
2527+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2528+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
2529+
__attribute__((convergent)) uint64_t4 WaveActiveBitOr(uint64_t4);
2530+
25012531
//===----------------------------------------------------------------------===//
25022532
// WaveActiveMax builtins
25032533
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3211,6 +3211,29 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
32113211
TheCall->setType(ArgTyExpr);
32123212
break;
32133213
}
3214+
case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
3215+
if (SemaRef.checkArgCount(TheCall, 1))
3216+
return true;
3217+
3218+
// Ensure input expr type is a scalar/vector and the same as the return type
3219+
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3220+
return true;
3221+
if (CheckWaveActive(&SemaRef, TheCall))
3222+
return true;
3223+
3224+
// Ensure expression parameter type can be interpreted as a uint
3225+
ExprResult Expr = TheCall->getArg(0);
3226+
QualType ArgTyExpr = Expr.get()->getType();
3227+
if (!ArgTyExpr->isIntegerType()) {
3228+
SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
3229+
diag::err_typecheck_convert_incompatible)
3230+
<< ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
3231+
return true;
3232+
}
3233+
3234+
TheCall->setType(ArgTyExpr);
3235+
break;
3236+
}
32143237
// Note: these are LLVM builtins for which we want to catch invalid intrinsic
32153238
// generation. Normal handling of these builtins will occur elsewhere.
32163239
case Builtin::BI__builtin_elementwise_bitreverse: {
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
7+
8+
// Test basic lowering to runtime function call.
9+
10+
// CHECK-LABEL: test_uint
11+
uint test_uint(uint expr) {
12+
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.or.i32([[TY]] %[[#]])
13+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.or.i32([[TY]] %[[#]])
14+
// CHECK: ret [[TY]] %[[RET]]
15+
return WaveActiveBitOr(expr);
16+
}
17+
18+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.or.i32([[TY]]) #[[#attr:]]
19+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.or.i32([[TY]]) #[[#attr:]]
20+
21+
// CHECK-LABEL: test_uint64_t
22+
uint64_t test_uint64_t(uint64_t expr) {
23+
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.or.i64([[TY]] %[[#]])
24+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.or.i64([[TY]] %[[#]])
25+
// CHECK: ret [[TY]] %[[RET]]
26+
return WaveActiveBitOr(expr);
27+
}
28+
29+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.or.i64([[TY]]) #[[#attr:]]
30+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.or.i64([[TY]]) #[[#attr:]]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
uint test_too_few_arg() {
4+
return __builtin_hlsl_wave_active_bit_or();
5+
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
6+
}
7+
8+
uint2 test_too_many_arg(uint2 p0) {
9+
return __builtin_hlsl_wave_active_bit_or(p0, p0);
10+
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
11+
}
12+
13+
bool test_expr_bool_type_check(bool p0) {
14+
return __builtin_hlsl_wave_active_bit_or(p0);
15+
// expected-error@-1 {{invalid operand of type 'bool'}}
16+
}
17+
18+
float test_expr_float_type_check(float p0) {
19+
return __builtin_hlsl_wave_active_bit_or(p0);
20+
// expected-error@-1 {{invalid operand of type 'float'}}
21+
}
22+
23+
bool2 test_expr_bool_vec_type_check(bool2 p0) {
24+
return __builtin_hlsl_wave_active_bit_or(p0);
25+
// expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
26+
}
27+
28+
float2 test_expr_float_type_check(float2 p0) {
29+
return __builtin_hlsl_wave_active_bit_or(p0);
30+
// expected-error@-1 {{invalid operand of type 'float2' (aka 'vector<float, 2>')}}
31+
}
32+
33+
struct S { float f; };
34+
35+
S test_expr_struct_type_check(S p0) {
36+
return __builtin_hlsl_wave_active_bit_or(p0);
37+
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
38+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1
151151
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
152152
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
153153
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
154+
def int_dx_wave_reduce_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [llvm_anyint_ty], [IntrConvergent, IntrNoMem]>;
154155
def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
155156
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
156157
def int_dx_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
120120
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
121121
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
122122
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
123+
def int_spv_wave_reduce_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [llvm_anyint_ty], [IntrConvergent, IntrNoMem]>;
123124
def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
124125
def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
125126
def int_spv_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
@@ -136,7 +137,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
136137
def int_spv_sclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
137138
def int_spv_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
138139

139-
// Create resource handle given the binding information. Returns a
140+
// Create resource handle given the binding information. Returns a
140141
// type appropriate for the kind of resource given the set id, binding id,
141142
// array size of the binding, as well as an index and an indicator
142143
// whether that index may be non-uniform.

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,10 @@ defvar WaveOpKind_Product = 1;
316316
defvar WaveOpKind_Min = 2;
317317
defvar WaveOpKind_Max = 3;
318318

319+
defvar WaveBitOpKind_And = 0;
320+
defvar WaveBitOpKind_Or = 1;
321+
defvar WaveBitOpKind_Xor = 2;
322+
319323
defvar SignedOpKind_Signed = 0;
320324
defvar SignedOpKind_Unsigned = 1;
321325

@@ -1069,6 +1073,24 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
10691073
let attributes = [Attributes<DXIL1_0, []>];
10701074
}
10711075

1076+
def WaveActiveBit : DXILOp<120, waveActiveBit> {
1077+
let Doc = "returns the result of the operation across waves";
1078+
let intrinsics = [
1079+
IntrinSelect<int_dx_wave_reduce_or,
1080+
[
1081+
IntrinArgIndex<0>, IntrinArgI8<WaveBitOpKind_Or>,
1082+
]>,
1083+
];
1084+
1085+
let arguments = [OverloadTy, Int8Ty];
1086+
let result = OverloadTy;
1087+
let overloads = [
1088+
Overloads<DXIL1_0, [Int32Ty, Int64Ty]>
1089+
];
1090+
let stages = [Stages<DXIL1_0, [all_stages]>];
1091+
let attributes = [Attributes<DXIL1_0, []>];
1092+
}
1093+
10721094
def WaveAllBitCount : DXILOp<135, waveAllOp> {
10731095
let Doc = "returns the count of bits set to 1 across the wave";
10741096
let intrinsics = [IntrinSelect<int_dx_wave_active_countbits>];

0 commit comments

Comments
 (0)