From 0e77b1ae638fbd0f1bbcf4970258d20aaf5e59d7 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Fri, 14 Nov 2025 02:14:23 +0800 Subject: [PATCH 1/2] [HLSL][DXIL][SPIRV] WavePrefixSum intrinsic support This commit adds WavePrefixSum intrinsic support to HLSL. --- clang/include/clang/Basic/Builtins.td | 6 ++ clang/lib/CodeGen/CGHLSLBuiltins.cpp | 7 ++ clang/lib/CodeGen/CGHLSLRuntime.h | 1 + .../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 99 +++++++++++++++++++ clang/lib/Sema/SemaHLSL.cpp | 35 ++++++- .../CodeGenHLSL/builtins/WavePrefixSum.hlsl | 46 +++++++++ .../BuiltIns/WavePrefixSum-errors.hlsl | 28 ++++++ llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 + llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 + llvm/lib/Target/DirectX/DXIL.td | 18 ++++ llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 2 + .../DirectX/DirectXTargetTransformInfo.cpp | 1 + .../Target/SPIRV/SPIRVInstructionSelector.cpp | 24 +++-- llvm/test/CodeGen/DirectX/WavePrefixSum.ll | 87 ++++++++++++++++ .../SPIRV/hlsl-intrinsics/WavePrefixSum.ll | 41 ++++++++ 15 files changed, 384 insertions(+), 13 deletions(-) create mode 100644 clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl create mode 100644 clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl create mode 100644 llvm/test/CodeGen/DirectX/WavePrefixSum.ll create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index d4a3e34a43c53..8a41388bd2244 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -5072,6 +5072,12 @@ def HLSLWaveGetLaneCount : LangBuiltin<"HLSL_LANG"> { let Prototype = "unsigned int()"; } +def HLSLWavePrefixSum : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_wave_prefix_sum"]; + let Attributes = [NoThrow, Const]; + let Prototype = "void(...)"; +} + def HLSLClamp : LangBuiltin<"HLSL_LANG"> { let Spellings = ["__builtin_hlsl_elementwise_clamp"]; let Attributes = [NoThrow, 
Const, CustomTypeChecking]; diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index b6928ce7d9c44..046cc98f874e6 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -864,6 +864,13 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, {OpExpr->getType()}), ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane"); } + case Builtin::BI__builtin_hlsl_wave_prefix_sum: { + Value *OpExpr = EmitScalarExpr(E->getArg(0)); + Intrinsic::ID IID = CGM.getHLSLRuntime().getWavePrefixSumIntrinsic(); + return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( + &CGM.getModule(), IID, {OpExpr->getType()}), + ArrayRef{OpExpr}, "hlsl.wave.prefix.sum"); + } case Builtin::BI__builtin_hlsl_elementwise_sign: { auto *Arg0 = E->getArg(0); Value *Op0 = EmitScalarExpr(Arg0); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index 488a322ca7569..7d90bf74df03c 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -144,6 +144,7 @@ class CGHLSLRuntime { GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane) + GENERATE_HLSL_INTRINSIC_FUNCTION(WavePrefixSum, wave_prefix_sum) GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh) GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh) GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitLow, firstbitlow) diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index 2e2703de18cb1..d3b9af9695016 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -2802,6 +2802,105 @@ __attribute__((convergent)) double3 WaveActiveSum(double3); _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum) __attribute__((convergent)) double4 
WaveActiveSum(double4); +//===----------------------------------------------------------------------===// +// WavePrefixSum builtins +//===----------------------------------------------------------------------===// + +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) half WavePrefixSum(half); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) half2 WavePrefixSum(half2); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) half3 WavePrefixSum(half3); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) half4 WavePrefixSum(half4); + +#ifdef __HLSL_ENABLE_16_BIT +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int16_t WavePrefixSum(int16_t); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int16_t2 WavePrefixSum(int16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int16_t3 WavePrefixSum(int16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int16_t4 WavePrefixSum(int16_t4); + +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint16_t WavePrefixSum(uint16_t); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint16_t2 WavePrefixSum(uint16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint16_t3 WavePrefixSum(uint16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.0) 
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint16_t4 WavePrefixSum(uint16_t4); +#endif + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int WavePrefixSum(int); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int2 WavePrefixSum(int2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int3 WavePrefixSum(int3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int4 WavePrefixSum(int4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint WavePrefixSum(uint); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint2 WavePrefixSum(uint2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint3 WavePrefixSum(uint3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint4 WavePrefixSum(uint4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int64_t WavePrefixSum(int64_t); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int64_t2 WavePrefixSum(int64_t2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int64_t3 WavePrefixSum(int64_t3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) int64_t4 WavePrefixSum(int64_t4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint64_t WavePrefixSum(uint64_t); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint64_t2 WavePrefixSum(uint64_t2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint64_t3 WavePrefixSum(uint64_t3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) uint64_t4 WavePrefixSum(uint64_t4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) 
+__attribute__((convergent)) float WavePrefixSum(float); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) float2 WavePrefixSum(float2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) float3 WavePrefixSum(float3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) float4 WavePrefixSum(float4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) double WavePrefixSum(double); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) double2 WavePrefixSum(double2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) double3 WavePrefixSum(double3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum) +__attribute__((convergent)) double4 WavePrefixSum(double4); + //===----------------------------------------------------------------------===// // sign builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index e95fe16e6cb6c..d8abdb9b75d18 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2892,10 +2892,13 @@ static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall, return false; } -static bool CheckWaveActive(Sema *S, CallExpr *TheCall) { +// Check that the argument is not a bool or vector<bool> +// Returns true on error +static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall, + unsigned ArgIndex) { QualType BoolType = S->getASTContext().BoolTy; - assert(TheCall->getNumArgs() >= 1); - QualType ArgType = TheCall->getArg(0)->getType(); + assert(ArgIndex < TheCall->getNumArgs()); + QualType ArgType = TheCall->getArg(ArgIndex)->getType(); auto *VTy = ArgType->getAs<VectorType>(); // is the bool or vector<bool> if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) || @@ -2909,6 +2912,18 @@ static bool CheckWaveActive(Sema *S, CallExpr *TheCall) { return false; } 
+static bool CheckWaveActive(Sema *S, CallExpr *TheCall) { + if (CheckNotBoolScalarOrVector(S, TheCall, 0)) + return true; + return false; +} + +static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) { + if (CheckNotBoolScalarOrVector(S, TheCall, 0)) + return true; + return false; +} + static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) { assert(TheCall->getNumArgs() == 3); Expr *Arg1 = TheCall->getArg(1); @@ -3371,6 +3386,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { return true; break; } + case Builtin::BI__builtin_hlsl_wave_prefix_sum: { + if (SemaRef.checkArgCount(TheCall, 1)) + return true; + + // Ensure input expr type is a scalar/vector and the same as the return type + if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0)) + return true; + if (CheckWavePrefix(&SemaRef, TheCall)) + return true; + ExprResult Expr = TheCall->getArg(0); + QualType ArgTyExpr = Expr.get()->getType(); + TheCall->setType(ArgTyExpr); + break; + } case Builtin::BI__builtin_hlsl_elementwise_splitdouble: { if (SemaRef.checkArgCount(TheCall, 3)) return true; diff --git a/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl new file mode 100644 index 0000000000000..7004ec0c5c4ac --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl @@ -0,0 +1,46 @@ +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV + +// Test basic lowering to runtime function call. 
+ +// CHECK-LABEL: test_int +int test_int(int expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i32([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.sum.i32([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WavePrefixSum(expr); +} + +// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.sum.i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.prefix.sum.i32([[TY]]) #[[#attr:]] + +// CHECK-LABEL: test_uint64_t +uint64_t test_uint64_t(uint64_t expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i64([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.sum.i64([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WavePrefixSum(expr); +} + +// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.sum.i64([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.prefix.sum.i64([[TY]]) #[[#attr:]] + +// Test basic lowering to runtime function call with array and float value. 
+ +// CHECK-LABEL: test_floatv4 +float4 test_floatv4(float4 expr) { + // CHECK-SPIRV: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func [[TY1:.*]] @llvm.spv.wave.prefix.sum.v4f32([[TY1]] %[[#]] + // CHECK-DXIL: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn [[TY1:.*]] @llvm.dx.wave.prefix.sum.v4f32([[TY1]] %[[#]]) + // CHECK: ret [[TY1]] %[[RET1]] + return WavePrefixSum(expr); +} + +// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.prefix.sum.v4f32([[TY1]]) #[[#attr]] +// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.prefix.sum.v4f32([[TY1]]) #[[#attr]] + +// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}} + diff --git a/clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl new file mode 100644 index 0000000000000..1e575c94e67a5 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl @@ -0,0 +1,28 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify + +int test_too_few_arg() { + return __builtin_hlsl_wave_prefix_sum(); + // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} +} + +float2 test_too_many_arg(float2 p0) { + return __builtin_hlsl_wave_prefix_sum(p0, p0); + // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} +} + +bool test_expr_bool_type_check(bool p0) { + return __builtin_hlsl_wave_prefix_sum(p0); + // expected-error@-1 {{invalid operand of type 'bool'}} +} + +bool2 test_expr_bool_vec_type_check(bool2 p0) { + return __builtin_hlsl_wave_prefix_sum(p0); + // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}} +} + +struct S { float f; }; + +S test_expr_struct_type_check(S p0) { + return __builtin_hlsl_wave_prefix_sum(p0); + // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td 
b/llvm/include/llvm/IR/IntrinsicsDirectX.td index d7db935ee07f1..a2de23d51acb5 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -164,6 +164,7 @@ def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrCon def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_get_lane_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent]>; +def int_dx_wave_prefix_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>; def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>; def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index f39c6cda2c579..9c05cc54809b4 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -129,6 +129,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_get_lane_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent]>; + def int_spv_wave_prefix_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>; def int_spv_radians : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>; def int_spv_group_memory_barrier_with_group_sync diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 67437f6969b27..2dc23a984ca33 100644 --- 
a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -1079,6 +1079,24 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> { let attributes = [Attributes]; } +def WavePrefixOp : DXILOp<121, wavePrefixOp> { + let Doc = "returns partial result of the computation in the corresponding lane"; + let intrinsics = [ + IntrinSelect, IntrinArgI8 + ]> + ]; + + let arguments = [OverloadTy, Int8Ty]; + let result = OverloadTy; + let overloads = [ + Overloads + ]; + let stages = [Stages]; + let attributes = [Attributes]; +} + def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> { let Doc = "returns the float16 stored in the low-half of the uint converted " "to a float"; diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index ce6e8121b9d94..8a2ed48b61557 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -96,6 +96,8 @@ static bool checkWaveOps(Intrinsic::ID IID) { case Intrinsic::dx_wave_reduce_umax: case Intrinsic::dx_wave_reduce_min: case Intrinsic::dx_wave_reduce_umin: + // Wave Prefix Op Variants + case Intrinsic::dx_wave_prefix_sum: return true; } } diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp index 6cacbf6564db2..214a164597dd2 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp @@ -56,6 +56,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable( case Intrinsic::dx_saturate: case Intrinsic::dx_splitdouble: case Intrinsic::dx_wave_readlane: + case Intrinsic::dx_wave_prefix_sum: case Intrinsic::dx_wave_reduce_max: case Intrinsic::dx_wave_reduce_min: case Intrinsic::dx_wave_reduce_sum: diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index fc87288a4a212..68e185cdd301f 100644 --- 
a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -225,8 +225,8 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectWaveReduceMin(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool IsUnsigned) const; - bool selectWaveReduceSum(Register ResVReg, const SPIRVType *ResType, - MachineInstr &I) const; + bool selectWaveSumWithGroupOp(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, int64_t GroupOp) const; bool selectConst(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; @@ -2452,7 +2452,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg, report_fatal_error("Input Type could not be determined."); SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); - // Retreive the operation to use based on input type + // Retrieve the operation to use based on input type bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat); auto IntegerOpcodeType = IsUnsigned ? SPIRV::OpGroupNonUniformUMax : SPIRV::OpGroupNonUniformSMax; @@ -2481,7 +2481,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMin(Register ResVReg, report_fatal_error("Input Type could not be determined."); SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); - // Retreive the operation to use based on input type + // Retrieve the operation to use based on input type bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat); auto IntegerOpcodeType = IsUnsigned ? 
SPIRV::OpGroupNonUniformUMin : SPIRV::OpGroupNonUniformSMin; @@ -2496,9 +2496,9 @@ bool SPIRVInstructionSelector::selectWaveReduceMin(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } -bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I) const { +bool SPIRVInstructionSelector::selectWaveSumWithGroupOp( + Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + int64_t GroupOp) const { assert(I.getNumOperands() == 3); assert(I.getOperand(2).isReg()); MachineBasicBlock &BB = *I.getParent(); @@ -2509,7 +2509,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg, report_fatal_error("Input Type could not be determined."); SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); - // Retreive the operation to use based on input type + // Retrieve the operation to use based on input type bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat); auto Opcode = IsFloatTy ? SPIRV::OpGroupNonUniformFAdd : SPIRV::OpGroupNonUniformIAdd; @@ -2518,7 +2518,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg, .addUse(GR.getSPIRVTypeID(ResType)) .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII, !STI.isShader())) - .addImm(SPIRV::GroupOperation::Reduce) + .addImm(GroupOp) .addUse(I.getOperand(2).getReg()); } @@ -3485,10 +3485,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_wave_reduce_min: return selectWaveReduceMin(ResVReg, ResType, I, /*IsUnsigned*/ false); case Intrinsic::spv_wave_reduce_sum: - return selectWaveReduceSum(ResVReg, ResType, I); + return selectWaveSumWithGroupOp(ResVReg, ResType, I, + SPIRV::GroupOperation::Reduce); case Intrinsic::spv_wave_readlane: return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformShuffle); + case Intrinsic::spv_wave_prefix_sum: + return selectWaveSumWithGroupOp(ResVReg, ResType, I, + SPIRV::GroupOperation::ExclusiveScan); case 
Intrinsic::spv_step: return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step); case Intrinsic::spv_radians: diff --git a/llvm/test/CodeGen/DirectX/WavePrefixSum.ll b/llvm/test/CodeGen/DirectX/WavePrefixSum.ll new file mode 100644 index 0000000000000..2596599798acb --- /dev/null +++ b/llvm/test/CodeGen/DirectX/WavePrefixSum.ll @@ -0,0 +1,87 @@ +; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s + +; Test that for scalar values, WavePrefixSum maps down to the DirectX op + +define noundef half @wave_prefix_sum_half(half noundef %expr) { +entry: +; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr, i8 0) + %ret = call half @llvm.dx.wave.prefix.sum.f16(half %expr) + ret half %ret +} + +define noundef float @wave_prefix_sum_float(float noundef %expr) { +entry: +; CHECK: call float @dx.op.wavePrefixOp.f32(i32 121, float %expr, i8 0) + %ret = call float @llvm.dx.wave.prefix.sum.f32(float %expr) + ret float %ret +} + +define noundef double @wave_prefix_sum_double(double noundef %expr) { +entry: +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr, i8 0) + %ret = call double @llvm.dx.wave.prefix.sum.f64(double %expr) + ret double %ret +} + +define noundef i16 @wave_prefix_sum_i16(i16 noundef %expr) { +entry: +; CHECK: call i16 @dx.op.wavePrefixOp.i16(i32 121, i16 %expr, i8 0) + %ret = call i16 @llvm.dx.wave.prefix.sum.i16(i16 %expr) + ret i16 %ret +} + +define noundef i32 @wave_prefix_sum_i32(i32 noundef %expr) { +entry: +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr, i8 0) + %ret = call i32 @llvm.dx.wave.prefix.sum.i32(i32 %expr) + ret i32 %ret +} + +define noundef i64 @wave_prefix_sum_i64(i64 noundef %expr) { +entry: +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr, i8 0) + %ret = call i64 @llvm.dx.wave.prefix.sum.i64(i64 %expr) + ret i64 %ret +} + +declare half @llvm.dx.wave.prefix.sum.f16(half) +declare float @llvm.dx.wave.prefix.sum.f32(float) +declare double 
@llvm.dx.wave.prefix.sum.f64(double) + +declare i16 @llvm.dx.wave.prefix.sum.i16(i16) +declare i32 @llvm.dx.wave.prefix.sum.i32(i32) +declare i64 @llvm.dx.wave.prefix.sum.i64(i64) + +; Test that for vector values, WavePrefixSum scalarizes and maps down to the +; DirectX op + +define noundef <2 x half> @wave_prefix_sum_v2half(<2 x half> noundef %expr) { +entry: +; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr.i0, i8 0) +; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr.i1, i8 0) + %ret = call <2 x half> @llvm.dx.wave.prefix.sum.v2f16(<2 x half> %expr) + ret <2 x half> %ret +} + +define noundef <3 x i32> @wave_prefix_sum_v3i32(<3 x i32> noundef %expr) { +entry: +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i0, i8 0) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i1, i8 0) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i2, i8 0) + %ret = call <3 x i32> @llvm.dx.wave.prefix.sum.v3i32(<3 x i32> %expr) + ret <3 x i32> %ret +} + +define noundef <4 x double> @wave_prefix_sum_v4f64(<4 x double> noundef %expr) { +entry: +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i0, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i1, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i2, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i3, i8 0) + %ret = call <4 x double> @llvm.dx.wave.prefix.sum.v4f64(<4 x double> %expr) + ret <4 x double> %ret +} + +declare <2 x half> @llvm.dx.wave.prefix.sum.v2f16(<2 x half>) +declare <3 x i32> @llvm.dx.wave.prefix.sum.v3i32(<3 x i32>) +declare <4 x double> @llvm.dx.wave.prefix.sum.v4f64(<4 x double>) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll new file mode 100644 index 0000000000000..5fb82fd9ebf19 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll @@ -0,0 +1,41 @@ +; RUN: llc 
-verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %} + +; Test lowering to spir-v backend for various types and scalar/vector + +; CHECK-DAG: %[[#f16:]] = OpTypeFloat 16 +; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32 +; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#v4_half:]] = OpTypeVector %[[#f16]] 4 +; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3 + +; CHECK-LABEL: Begin function test_float +; CHECK: %[[#fexpr:]] = OpFunctionParameter %[[#f32]] +define float @test_float(float %fexpr) { +entry: +; CHECK: %[[#fret:]] = OpGroupNonUniformFAdd %[[#f32]] %[[#scope]] ExclusiveScan %[[#fexpr]] + %0 = call float @llvm.spv.wave.prefix.sum.f32(float %fexpr) + ret float %0 +} + +; CHECK-LABEL: Begin function test_int +; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]] +define i32 @test_int(i32 %iexpr) { +entry: +; CHECK: %[[#iret:]] = OpGroupNonUniformIAdd %[[#uint]] %[[#scope]] ExclusiveScan %[[#iexpr]] + %0 = call i32 @llvm.spv.wave.prefix.sum.i32(i32 %iexpr) + ret i32 %0 +} + +; CHECK-LABEL: Begin function test_vhalf +; CHECK: %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]] +define <4 x half> @test_vhalf(<4 x half> %vbexpr) { +entry: +; CHECK: %[[#vhalfret:]] = OpGroupNonUniformFAdd %[[#v4_half]] %[[#scope]] ExclusiveScan %[[#vbexpr]] + %0 = call <4 x half> @llvm.spv.wave.prefix.sum.v4half(<4 x half> %vbexpr) + ret <4 x half> %0 +} + +declare float @llvm.spv.wave.prefix.sum.f32(float) +declare i32 @llvm.spv.wave.prefix.sum.i32(i32) +declare <4 x half> @llvm.spv.wave.prefix.sum.v4half(<4 x half>) From 1edeecc8faa0e79f046bfc7504b4764190dcc431 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Fri, 14 Nov 2025 04:12:34 +0800 Subject: [PATCH 2/2] [HLSL][DXIL][SPIRV] WavePrefixSum: added unsigned DXIL intrinsic variant - Added missing int_dx_wave_prefix_usum to DXIL.td - NFC change to HLSL CodeGen that introduced 
getUnsignedIntrinsicVariant so it can be used by the new GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED macro to pick the correct intrinsic based on the unsigned condition bool. --- clang/lib/CodeGen/CGHLSLBuiltins.cpp | 83 ++++-------------- clang/lib/CodeGen/CGHLSLRuntime.cpp | 23 +++++ clang/lib/CodeGen/CGHLSLRuntime.h | 40 ++++++++- .../CodeGenHLSL/builtins/WavePrefixSum.hlsl | 4 +- llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 + llvm/lib/Target/DirectX/DXIL.td | 12 ++- .../DirectX/DirectXTargetTransformInfo.cpp | 3 +- llvm/test/CodeGen/DirectX/WavePrefixSum.ll | 86 +++++++++++++++---- 8 files changed, 162 insertions(+), 90 deletions(-) diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 046cc98f874e6..f79433e755f94 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -240,61 +240,6 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) { return RT.getFirstBitUHighIntrinsic(); } -// Return wave active sum that corresponds to the QT scalar type -static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch, - CGHLSLRuntime &RT, QualType QT) { - switch (Arch) { - case llvm::Triple::spirv: - return Intrinsic::spv_wave_reduce_sum; - case llvm::Triple::dxil: { - if (QT->isUnsignedIntegerType()) - return Intrinsic::dx_wave_reduce_usum; - return Intrinsic::dx_wave_reduce_sum; - } - default: - llvm_unreachable("Intrinsic WaveActiveSum" - " not supported by target architecture"); - } -} - -// Return wave active max that corresponds to the QT scalar type -static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch, - CGHLSLRuntime &RT, QualType QT) { - switch (Arch) { - case llvm::Triple::spirv: - if (QT->isUnsignedIntegerType()) - return Intrinsic::spv_wave_reduce_umax; - return Intrinsic::spv_wave_reduce_max; - case llvm::Triple::dxil: { - if (QT->isUnsignedIntegerType()) - return Intrinsic::dx_wave_reduce_umax; - return 
Intrinsic::dx_wave_reduce_max; - } - default: - llvm_unreachable("Intrinsic WaveActiveMax" - " not supported by target architecture"); - } -} - -// Return wave active min that corresponds to the QT scalar type -static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch, - CGHLSLRuntime &RT, QualType QT) { - switch (Arch) { - case llvm::Triple::spirv: - if (QT->isUnsignedIntegerType()) - return Intrinsic::spv_wave_reduce_umin; - return Intrinsic::spv_wave_reduce_min; - case llvm::Triple::dxil: { - if (QT->isUnsignedIntegerType()) - return Intrinsic::dx_wave_reduce_umin; - return Intrinsic::dx_wave_reduce_min; - } - default: - llvm_unreachable("Intrinsic WaveActiveMin" - " not supported by target architecture"); - } -} - // Returns the mangled name for a builtin function that the SPIR-V backend // will expand into a spec Constant. static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType, @@ -794,33 +739,33 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, ArrayRef{OpExpr}); } case Builtin::BI__builtin_hlsl_wave_active_sum: { - // Due to the use of variadic arguments, explicitly retreive argument + // Due to the use of variadic arguments, explicitly retrieve argument Value *OpExpr = EmitScalarExpr(E->getArg(0)); - Intrinsic::ID IID = getWaveActiveSumIntrinsic( - getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), - E->getArg(0)->getType()); + QualType QT = E->getArg(0)->getType(); + Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveSumIntrinsic( + QT->isUnsignedIntegerType()); return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( &CGM.getModule(), IID, {OpExpr->getType()}), ArrayRef{OpExpr}, "hlsl.wave.active.sum"); } case Builtin::BI__builtin_hlsl_wave_active_max: { - // Due to the use of variadic arguments, explicitly retreive argument + // Due to the use of variadic arguments, explicitly retrieve argument Value *OpExpr = EmitScalarExpr(E->getArg(0)); - Intrinsic::ID IID = 
getWaveActiveMaxIntrinsic( - getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), - E->getArg(0)->getType()); + QualType QT = E->getArg(0)->getType(); + Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMaxIntrinsic( + QT->isUnsignedIntegerType()); return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( &CGM.getModule(), IID, {OpExpr->getType()}), ArrayRef{OpExpr}, "hlsl.wave.active.max"); } case Builtin::BI__builtin_hlsl_wave_active_min: { - // Due to the use of variadic arguments, explicitly retreive argument + // Due to the use of variadic arguments, explicitly retrieve argument Value *OpExpr = EmitScalarExpr(E->getArg(0)); - Intrinsic::ID IID = getWaveActiveMinIntrinsic( - getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), - E->getArg(0)->getType()); + QualType QT = E->getArg(0)->getType(); + Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMinIntrinsic( + QT->isUnsignedIntegerType()); return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( &CGM.getModule(), IID, {OpExpr->getType()}), @@ -866,7 +811,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, } case Builtin::BI__builtin_hlsl_wave_prefix_sum: { Value *OpExpr = EmitScalarExpr(E->getArg(0)); - Intrinsic::ID IID = CGM.getHLSLRuntime().getWavePrefixSumIntrinsic(); + QualType QT = E->getArg(0)->getType(); + Intrinsic::ID IID = CGM.getHLSLRuntime().getWavePrefixSumIntrinsic( + QT->isUnsignedIntegerType()); return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( &CGM.getModule(), IID, {OpExpr->getType()}), ArrayRef{OpExpr}, "hlsl.wave.prefix.sum"); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 4bdba9b3da502..f9b1928ac7c45 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -276,6 +276,29 @@ llvm::Triple::ArchType CGHLSLRuntime::getArch() { return CGM.getTarget().getTriple().getArch(); } +llvm::Intrinsic::ID +CGHLSLRuntime::getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID) { + 
switch (IID) { + // DXIL intrinsics + case Intrinsic::dx_wave_reduce_sum: + return Intrinsic::dx_wave_reduce_usum; + case Intrinsic::dx_wave_reduce_max: + return Intrinsic::dx_wave_reduce_umax; + case Intrinsic::dx_wave_reduce_min: + return Intrinsic::dx_wave_reduce_umin; + case Intrinsic::dx_wave_prefix_sum: + return Intrinsic::dx_wave_prefix_usum; + + // SPIR-V intrinsics + case Intrinsic::spv_wave_reduce_max: + return Intrinsic::spv_wave_reduce_umax; + case Intrinsic::spv_wave_reduce_min: + return Intrinsic::spv_wave_reduce_umin; + default: + return IID; + } +} + // Emits constant global variables for buffer constants declarations // and creates metadata linking the constant globals with the buffer global. void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl, diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index 7d90bf74df03c..671c7d434edbb 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -49,6 +49,33 @@ } \ } +// A function generator macro for picking the right intrinsic for the target +// backend given IsUnsigned boolean condition. If IsUnsigned == true, it calls +// getUnsignedIntrinsicVariant(IID) to retrieve the unsigned variant of the +// intrinsic else the regular intrinsic is returned. (NOTE: +// getUnsignedIntrinsicVariant returns IID itself if there is no unsigned +// variant). +#define GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(FunctionName, \ + IntrinsicPostfix) \ + llvm::Intrinsic::ID get##FunctionName##Intrinsic(bool IsUnsigned) { \ + llvm::Triple::ArchType Arch = getArch(); \ + switch (Arch) { \ + case llvm::Triple::dxil: { \ + static constexpr llvm::Intrinsic::ID IID = \ + llvm::Intrinsic::dx_##IntrinsicPostfix; \ + return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID; \ + } \ + case llvm::Triple::spirv: { \ + static constexpr llvm::Intrinsic::ID IID = \ + llvm::Intrinsic::spv_##IntrinsicPostfix; \ + return IsUnsigned ? 
getUnsignedIntrinsicVariant(IID) : IID; \ + } \ + default: \ + llvm_unreachable("Intrinsic " #IntrinsicPostfix \ + " not supported by target architecture"); \ + } \ + } + using ResourceClass = llvm::dxil::ResourceClass; namespace llvm { @@ -141,10 +168,17 @@ class CGHLSLRuntime { GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits) + GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveSum, + wave_reduce_sum) + GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMax, + wave_reduce_max) + GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMin, + wave_reduce_min) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane) - GENERATE_HLSL_INTRINSIC_FUNCTION(WavePrefixSum, wave_prefix_sum) + GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WavePrefixSum, + wave_prefix_sum) GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh) GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh) GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitLow, firstbitlow) @@ -247,6 +281,10 @@ class CGHLSLRuntime { llvm::Triple::ArchType getArch(); + // Returns the unsigned variant of the given intrinsic ID if possible, + // otherwise, the original intrinsic ID is returned. 
+ llvm::Intrinsic::ID getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID); + llvm::DenseMap LayoutTypes; unsigned SPIRVLastAssignedInputSemanticLocation = 0; }; diff --git a/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl index 7004ec0c5c4ac..f22aa69ba45d5 100644 --- a/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl +++ b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl @@ -21,12 +21,12 @@ int test_int(int expr) { // CHECK-LABEL: test_uint64_t uint64_t test_uint64_t(uint64_t expr) { // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i64([[TY]] %[[#]]) - // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.sum.i64([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.usum.i64([[TY]] %[[#]]) // CHECK: ret [[TY]] %[[RET]] return WavePrefixSum(expr); } -// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.sum.i64([[TY]]) #[[#attr:]] +// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.usum.i64([[TY]]) #[[#attr:]] // CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.prefix.sum.i64([[TY]]) #[[#attr:]] // Test basic lowering to runtime function call with array and float value. 
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index a2de23d51acb5..b613d96275d03 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -165,6 +165,7 @@ def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0 def int_dx_wave_get_lane_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent]>; def int_dx_wave_prefix_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; +def int_dx_wave_prefix_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>; def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>; def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 2dc23a984ca33..6075c9ba23931 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -1084,11 +1084,17 @@ def WavePrefixOp : DXILOp<121, wavePrefixOp> { let intrinsics = [ IntrinSelect<int_dx_wave_prefix_sum, [ - IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum> - ]> + IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum>, + IntrinArgI8<SignedOpKind_Signed> + ]>, + IntrinSelect<int_dx_wave_prefix_usum, + [ + IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum>, + IntrinArgI8<SignedOpKind_Unsigned> + ]>, ]; let arguments = [OverloadTy, Int8Ty, Int8Ty]; let result = OverloadTy; let overloads = [ Overloads diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp index 214a164597dd2..7ae814648cadc 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp @@ -56,13 +56,14 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable( case Intrinsic::dx_saturate: case Intrinsic::dx_splitdouble: case
Intrinsic::dx_wave_readlane: - case Intrinsic::dx_wave_prefix_sum: case Intrinsic::dx_wave_reduce_max: case Intrinsic::dx_wave_reduce_min: case Intrinsic::dx_wave_reduce_sum: + case Intrinsic::dx_wave_prefix_sum: case Intrinsic::dx_wave_reduce_umax: case Intrinsic::dx_wave_reduce_umin: case Intrinsic::dx_wave_reduce_usum: + case Intrinsic::dx_wave_prefix_usum: case Intrinsic::dx_imad: case Intrinsic::dx_umad: return true; diff --git a/llvm/test/CodeGen/DirectX/WavePrefixSum.ll b/llvm/test/CodeGen/DirectX/WavePrefixSum.ll index 2596599798acb..ed8c2b2b85465 100644 --- a/llvm/test/CodeGen/DirectX/WavePrefixSum.ll +++ b/llvm/test/CodeGen/DirectX/WavePrefixSum.ll @@ -4,46 +4,67 @@ define noundef half @wave_prefix_sum_half(half noundef %expr) { entry: -; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr, i8 0) +; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr, i8 0, i8 0) %ret = call half @llvm.dx.wave.prefix.sum.f16(half %expr) ret half %ret } define noundef float @wave_prefix_sum_float(float noundef %expr) { entry: -; CHECK: call float @dx.op.wavePrefixOp.f32(i32 121, float %expr, i8 0) +; CHECK: call float @dx.op.wavePrefixOp.f32(i32 121, float %expr, i8 0, i8 0) %ret = call float @llvm.dx.wave.prefix.sum.f32(float %expr) ret float %ret } define noundef double @wave_prefix_sum_double(double noundef %expr) { entry: -; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr, i8 0, i8 0) %ret = call double @llvm.dx.wave.prefix.sum.f64(double %expr) ret double %ret } define noundef i16 @wave_prefix_sum_i16(i16 noundef %expr) { entry: -; CHECK: call i16 @dx.op.wavePrefixOp.i16(i32 121, i16 %expr, i8 0) +; CHECK: call i16 @dx.op.wavePrefixOp.i16(i32 121, i16 %expr, i8 0, i8 0) %ret = call i16 @llvm.dx.wave.prefix.sum.i16(i16 %expr) ret i16 %ret } define noundef i32 @wave_prefix_sum_i32(i32 noundef %expr) { entry: -; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, 
i32 %expr, i8 0) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr, i8 0, i8 0) %ret = call i32 @llvm.dx.wave.prefix.sum.i32(i32 %expr) ret i32 %ret } define noundef i64 @wave_prefix_sum_i64(i64 noundef %expr) { entry: -; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr, i8 0) +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr, i8 0, i8 0) %ret = call i64 @llvm.dx.wave.prefix.sum.i64(i64 %expr) ret i64 %ret } +define noundef i16 @wave_prefix_usum_i16(i16 noundef %expr) { +entry: +; CHECK: call i16 @dx.op.wavePrefixOp.i16(i32 121, i16 %expr, i8 0, i8 1) + %ret = call i16 @llvm.dx.wave.prefix.usum.i16(i16 %expr) + ret i16 %ret +} + +define noundef i32 @wave_prefix_usum_i32(i32 noundef %expr) { +entry: +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr, i8 0, i8 1) + %ret = call i32 @llvm.dx.wave.prefix.usum.i32(i32 %expr) + ret i32 %ret +} + +define noundef i64 @wave_prefix_usum_i64(i64 noundef %expr) { +entry: +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr, i8 0, i8 1) + %ret = call i64 @llvm.dx.wave.prefix.usum.i64(i64 %expr) + ret i64 %ret +} + declare half @llvm.dx.wave.prefix.sum.f16(half) declare float @llvm.dx.wave.prefix.sum.f32(float) declare double @llvm.dx.wave.prefix.sum.f64(double) @@ -52,32 +73,36 @@ declare i16 @llvm.dx.wave.prefix.sum.i16(i16) declare i32 @llvm.dx.wave.prefix.sum.i32(i32) declare i64 @llvm.dx.wave.prefix.sum.i64(i64) +declare i16 @llvm.dx.wave.prefix.usum.i16(i16) +declare i32 @llvm.dx.wave.prefix.usum.i32(i32) +declare i64 @llvm.dx.wave.prefix.usum.i64(i64) + ; Test that for vector values, WavePrefixSum scalarizes and maps down to the ; DirectX op define noundef <2 x half> @wave_prefix_sum_v2half(<2 x half> noundef %expr) { entry: -; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr.i0, i8 0) -; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr.i1, i8 0) +; CHECK: call half @dx.op.wavePrefixOp.f16(i32 121, half %expr.i0, i8 0, i8 0) +; CHECK: call half 
@dx.op.wavePrefixOp.f16(i32 121, half %expr.i1, i8 0, i8 0) %ret = call <2 x half> @llvm.dx.wave.prefix.sum.v2f16(<2 x half> %expr) ret <2 x half> %ret } define noundef <3 x i32> @wave_prefix_sum_v3i32(<3 x i32> noundef %expr) { entry: -; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i0, i8 0) -; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i1, i8 0) -; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i2, i8 0) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i0, i8 0, i8 0) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i1, i8 0, i8 0) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i2, i8 0, i8 0) %ret = call <3 x i32> @llvm.dx.wave.prefix.sum.v3i32(<3 x i32> %expr) ret <3 x i32> %ret } define noundef <4 x double> @wave_prefix_sum_v4f64(<4 x double> noundef %expr) { entry: -; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i0, i8 0) -; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i1, i8 0) -; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i2, i8 0) -; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i3, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i0, i8 0, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i1, i8 0, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i2, i8 0, i8 0) +; CHECK: call double @dx.op.wavePrefixOp.f64(i32 121, double %expr.i3, i8 0, i8 0) %ret = call <4 x double> @llvm.dx.wave.prefix.sum.v464(<4 x double> %expr) ret <4 x double> %ret } @@ -85,3 +110,34 @@ entry: declare <2 x half> @llvm.dx.wave.prefix.sum.v2f16(<2 x half>) declare <3 x i32> @llvm.dx.wave.prefix.sum.v3i32(<3 x i32>) declare <4 x double> @llvm.dx.wave.prefix.sum.v4f64(<4 x double>) + +define noundef <2 x i16> @wave_prefix_usum_v2i16(<2 x i16> noundef %expr) { +entry: +; CHECK: call i16 @dx.op.wavePrefixOp.i16(i32 121, i16 %expr.i0, i8 0, i8 1) +; 
CHECK: call i16 @dx.op.wavePrefixOp.i16(i32 121, i16 %expr.i1, i8 0, i8 1) + %ret = call <2 x i16> @llvm.dx.wave.prefix.usum.v2i16(<2 x i16> %expr) + ret <2 x i16> %ret +} + +define noundef <3 x i32> @wave_prefix_usum_v3i32(<3 x i32> noundef %expr) { +entry: +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i0, i8 0, i8 1) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i1, i8 0, i8 1) +; CHECK: call i32 @dx.op.wavePrefixOp.i32(i32 121, i32 %expr.i2, i8 0, i8 1) + %ret = call <3 x i32> @llvm.dx.wave.prefix.usum.v3i32(<3 x i32> %expr) + ret <3 x i32> %ret +} + +define noundef <4 x i64> @wave_prefix_usum_v4i64(<4 x i64> noundef %expr) { +entry: +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr.i0, i8 0, i8 1) +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr.i1, i8 0, i8 1) +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr.i2, i8 0, i8 1) +; CHECK: call i64 @dx.op.wavePrefixOp.i64(i32 121, i64 %expr.i3, i8 0, i8 1) + %ret = call <4 x i64> @llvm.dx.wave.prefix.usum.v4i64(<4 x i64> %expr) + ret <4 x i64> %ret +} + +declare <2 x i16> @llvm.dx.wave.prefix.usum.v2i16(<2 x i16>) +declare <3 x i32> @llvm.dx.wave.prefix.usum.v3i32(<3 x i32>) +declare <4 x i64> @llvm.dx.wave.prefix.usum.v4i64(<4 x i64>)