Skip to content

Conversation

@kcloudy0717
Copy link

Issue: #99172

  • Implement WavePrefixSum clang builtin
  • Link WavePrefixSum clang builtin with hlsl_intrinsics.h
  • Add sema checks for WavePrefixSum to CheckHLSLBuiltinFunctionCall in SemaChecking.cpp
  • Add codegen for WavePrefixSum to EmitHLSLBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl
  • Create the int_dx_WavePrefixSum intrinsic in IntrinsicsDirectX.td
  • Create the DXILOpMapping of int_dx_WavePrefixSum to 121 in DXIL.td
  • Create the WavePrefixSum.ll and WavePrefixSum_errors.ll tests in llvm/test/CodeGen/DirectX/
  • Create the int_spv_WavePrefixSum intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the WavePrefixSum lowering and map it to int_spv_WavePrefixSum in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll

I also added a new macro GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED in conjunction with the new function getUnsignedIntrinsicVariant to make selecting unsigned variants of the intrinsic easier. As a result, I was able to replace getWaveActiveSumIntrinsic, getWaveActiveMaxIntrinsic, and getWaveActiveMinIntrinsic using the new macro.

This commit adds WavePrefixSum intrinsic support to HLSL.
- Added missing int_dx_wave_prefix_usum to DXIL.td
- NFC change to HLSL CodeGen that introduced getUnsignedIntrinsicVariant so it can be used by the new GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED macro to pick the correct intrinsic based on the unsigned condition bool.
@github-actions
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen IR generation bugs: mangling, exceptions, etc. backend:DirectX HLSL HLSL Language Support backend:SPIR-V llvm:ir labels Nov 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2025

@llvm/pr-subscribers-backend-spir-v
@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-hlsl

Author: Kai (kcloudy0717)

Changes

Issue: #99172

  • Implement WavePrefixSum clang builtin
  • Link WavePrefixSum clang builtin with hlsl_intrinsics.h
  • Add sema checks for WavePrefixSum to CheckHLSLBuiltinFunctionCall in SemaChecking.cpp
  • Add codegen for WavePrefixSum to EmitHLSLBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl
  • Create the int_dx_WavePrefixSum intrinsic in IntrinsicsDirectX.td
  • Create the DXILOpMapping of int_dx_WavePrefixSum to 121 in DXIL.td
  • Create the WavePrefixSum.ll and WavePrefixSum_errors.ll tests in llvm/test/CodeGen/DirectX/
  • Create the int_spv_WavePrefixSum intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the WavePrefixSum lowering and map it to int_spv_WavePrefixSum in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll

I also added a new macro GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED in conjunction with the new function getUnsignedIntrinsicVariant to make selecting unsigned variants of the intrinsic easier. As a result, I was able to replace getWaveActiveSumIntrinsic, getWaveActiveMaxIntrinsic, and getWaveActiveMinIntrinsic using the new macro.


Patch is 39.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167946.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+21-67)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.cpp (+23)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+39)
  • (modified) clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h (+99)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+32-3)
  • (added) clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl (+46)
  • (added) clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl (+28)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+2)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+24)
  • (modified) llvm/lib/Target/DirectX/DXILShaderFlags.cpp (+2)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+14-10)
  • (added) llvm/test/CodeGen/DirectX/WavePrefixSum.ll (+143)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll (+41)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index d4a3e34a43c53..8a41388bd2244 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5072,6 +5072,12 @@ def HLSLWaveGetLaneCount : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "unsigned int()";
 }
 
+def HLSLWavePrefixSum : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_prefix_sum"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_clamp"];
   let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index b6928ce7d9c44..f79433e755f94 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -240,61 +240,6 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
   return RT.getFirstBitUHighIntrinsic();
 }
 
-// Return wave active sum that corresponds to the QT scalar type
-static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
-                                               CGHLSLRuntime &RT, QualType QT) {
-  switch (Arch) {
-  case llvm::Triple::spirv:
-    return Intrinsic::spv_wave_reduce_sum;
-  case llvm::Triple::dxil: {
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::dx_wave_reduce_usum;
-    return Intrinsic::dx_wave_reduce_sum;
-  }
-  default:
-    llvm_unreachable("Intrinsic WaveActiveSum"
-                     " not supported by target architecture");
-  }
-}
-
-// Return wave active max that corresponds to the QT scalar type
-static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
-                                               CGHLSLRuntime &RT, QualType QT) {
-  switch (Arch) {
-  case llvm::Triple::spirv:
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::spv_wave_reduce_umax;
-    return Intrinsic::spv_wave_reduce_max;
-  case llvm::Triple::dxil: {
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::dx_wave_reduce_umax;
-    return Intrinsic::dx_wave_reduce_max;
-  }
-  default:
-    llvm_unreachable("Intrinsic WaveActiveMax"
-                     " not supported by target architecture");
-  }
-}
-
-// Return wave active min that corresponds to the QT scalar type
-static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
-                                               CGHLSLRuntime &RT, QualType QT) {
-  switch (Arch) {
-  case llvm::Triple::spirv:
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::spv_wave_reduce_umin;
-    return Intrinsic::spv_wave_reduce_min;
-  case llvm::Triple::dxil: {
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::dx_wave_reduce_umin;
-    return Intrinsic::dx_wave_reduce_min;
-  }
-  default:
-    llvm_unreachable("Intrinsic WaveActiveMin"
-                     " not supported by target architecture");
-  }
-}
-
 // Returns the mangled name for a builtin function that the SPIR-V backend
 // will expand into a spec Constant.
 static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
@@ -794,33 +739,33 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         ArrayRef{OpExpr});
   }
   case Builtin::BI__builtin_hlsl_wave_active_sum: {
-    // Due to the use of variadic arguments, explicitly retreive argument
+    // Due to the use of variadic arguments, explicitly retrieve argument
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
-    Intrinsic::ID IID = getWaveActiveSumIntrinsic(
-        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
-        E->getArg(0)->getType());
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveSumIntrinsic(
+        QT->isUnsignedIntegerType());
 
     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                                &CGM.getModule(), IID, {OpExpr->getType()}),
                            ArrayRef{OpExpr}, "hlsl.wave.active.sum");
   }
   case Builtin::BI__builtin_hlsl_wave_active_max: {
-    // Due to the use of variadic arguments, explicitly retreive argument
+    // Due to the use of variadic arguments, explicitly retrieve argument
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
-    Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
-        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
-        E->getArg(0)->getType());
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMaxIntrinsic(
+        QT->isUnsignedIntegerType());
 
     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                                &CGM.getModule(), IID, {OpExpr->getType()}),
                            ArrayRef{OpExpr}, "hlsl.wave.active.max");
   }
   case Builtin::BI__builtin_hlsl_wave_active_min: {
-    // Due to the use of variadic arguments, explicitly retreive argument
+    // Due to the use of variadic arguments, explicitly retrieve argument
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
-    Intrinsic::ID IID = getWaveActiveMinIntrinsic(
-        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
-        E->getArg(0)->getType());
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMinIntrinsic(
+        QT->isUnsignedIntegerType());
 
     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                                &CGM.getModule(), IID, {OpExpr->getType()}),
@@ -864,6 +809,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
             {OpExpr->getType()}),
         ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
   }
+  case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
+    Value *OpExpr = EmitScalarExpr(E->getArg(0));
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWavePrefixSumIntrinsic(
+        QT->isUnsignedIntegerType());
+    return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
+                               &CGM.getModule(), IID, {OpExpr->getType()}),
+                           ArrayRef{OpExpr}, "hlsl.wave.prefix.sum");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
     auto *Arg0 = E->getArg(0);
     Value *Op0 = EmitScalarExpr(Arg0);
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 4bdba9b3da502..f9b1928ac7c45 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -276,6 +276,29 @@ llvm::Triple::ArchType CGHLSLRuntime::getArch() {
   return CGM.getTarget().getTriple().getArch();
 }
 
+llvm::Intrinsic::ID
+CGHLSLRuntime::getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID) {
+  switch (IID) {
+    // DXIL intrinsics
+  case Intrinsic::dx_wave_reduce_sum:
+    return Intrinsic::dx_wave_reduce_usum;
+  case Intrinsic::dx_wave_reduce_max:
+    return Intrinsic::dx_wave_reduce_umax;
+  case Intrinsic::dx_wave_reduce_min:
+    return Intrinsic::dx_wave_reduce_umin;
+  case Intrinsic::dx_wave_prefix_sum:
+    return Intrinsic::dx_wave_prefix_usum;
+
+    // SPIR-V intrinsics
+  case Intrinsic::spv_wave_reduce_max:
+    return Intrinsic::spv_wave_reduce_umax;
+  case Intrinsic::spv_wave_reduce_min:
+    return Intrinsic::spv_wave_reduce_umin;
+  default:
+    return IID;
+  }
+}
+
 // Emits constant global variables for buffer constants declarations
 // and creates metadata linking the constant globals with the buffer global.
 void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 488a322ca7569..671c7d434edbb 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -49,6 +49,33 @@
     }                                                                          \
   }
 
+// A function generator macro for picking the right intrinsic for the target
+// backend given IsUnsigned boolean condition. If IsUnsigned == true, it calls
+// getUnsignedIntrinsicVariant(IID) to retrieve the unsigned variant of the
+// intrinsic else the regular intrinsic is returned. (NOTE:
+// getUnsignedIntrinsicVariant returns IID itself if there is no unsigned
+// variant).
+#define GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(FunctionName,         \
+                                                         IntrinsicPostfix)     \
+  llvm::Intrinsic::ID get##FunctionName##Intrinsic(bool IsUnsigned) {          \
+    llvm::Triple::ArchType Arch = getArch();                                   \
+    switch (Arch) {                                                            \
+    case llvm::Triple::dxil: {                                                 \
+      static constexpr llvm::Intrinsic::ID IID =                               \
+          llvm::Intrinsic::dx_##IntrinsicPostfix;                              \
+      return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID;              \
+    }                                                                          \
+    case llvm::Triple::spirv: {                                                \
+      static constexpr llvm::Intrinsic::ID IID =                               \
+          llvm::Intrinsic::spv_##IntrinsicPostfix;                             \
+      return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID;              \
+    }                                                                          \
+    default:                                                                   \
+      llvm_unreachable("Intrinsic " #IntrinsicPostfix                          \
+                       " not supported by target architecture");               \
+    }                                                                          \
+  }
+
 using ResourceClass = llvm::dxil::ResourceClass;
 
 namespace llvm {
@@ -141,9 +168,17 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveSum,
+                                                   wave_reduce_sum)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMax,
+                                                   wave_reduce_max)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMin,
+                                                   wave_reduce_min)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WavePrefixSum,
+                                                   wave_prefix_sum)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitLow, firstbitlow)
@@ -246,6 +281,10 @@ class CGHLSLRuntime {
 
   llvm::Triple::ArchType getArch();
 
+  // Returns the unsigned variant of the given intrinsic ID if possible,
+  // otherwise, the original intrinsic ID is returned.
+  llvm::Intrinsic::ID getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID);
+
   llvm::DenseMap<const clang::RecordType *, llvm::TargetExtType *> LayoutTypes;
   unsigned SPIRVLastAssignedInputSemanticLocation = 0;
 };
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 2e2703de18cb1..d3b9af9695016 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2802,6 +2802,105 @@ __attribute__((convergent)) double3 WaveActiveSum(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
 __attribute__((convergent)) double4 WaveActiveSum(double4);
 
+//===----------------------------------------------------------------------===//
+// WavePrefixSum builtins
+//===----------------------------------------------------------------------===//
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half WavePrefixSum(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half2 WavePrefixSum(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half3 WavePrefixSum(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half4 WavePrefixSum(half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t WavePrefixSum(int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t2 WavePrefixSum(int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t3 WavePrefixSum(int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t4 WavePrefixSum(int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t WavePrefixSum(uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t2 WavePrefixSum(uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t3 WavePrefixSum(uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t4 WavePrefixSum(uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int WavePrefixSum(int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int2 WavePrefixSum(int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int3 WavePrefixSum(int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int4 WavePrefixSum(int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint WavePrefixSum(uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint2 WavePrefixSum(uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint3 WavePrefixSum(uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint4 WavePrefixSum(uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t WavePrefixSum(int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t2 WavePrefixSum(int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t3 WavePrefixSum(int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t4 WavePrefixSum(int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t WavePrefixSum(uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t2 WavePrefixSum(uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t3 WavePrefixSum(uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t4 WavePrefixSum(uint64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float WavePrefixSum(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float2 WavePrefixSum(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float3 WavePrefixSum(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float4 WavePrefixSum(float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double WavePrefixSum(double);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double2 WavePrefixSum(double2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double3 WavePrefixSum(double3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double4 WavePrefixSum(double4);
+
 //===----------------------------------------------------------------------===//
 // sign builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index e95fe16e6cb6c..d8abdb9b75d18 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2892,10 +2892,13 @@ static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
   return false;
 }
 
-static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
+// Check that the argument is not a bool or vector<bool>
+// Returns true on error
+static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall,
+                                       unsigned ArgIndex) {
   QualType BoolType = S->getASTContext().BoolTy;
-  assert(TheCall->getNumArgs() >= 1);
-  QualType ArgType = TheCall->getArg(0)->getType();
+  assert(ArgIndex < TheCall->getNumArgs());
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
   auto *VTy = ArgType->getAs<VectorType>();
   // is the bool or vector<bool>
   if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
@@ -2909,6 +2912,18 @@ static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
+  if (CheckNotBoolScalarOrVector(S, TheCall, 0))
+    return true;
+  return false;
+}
+
+static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
+  if (CheckNotBoolScalarOrVector(S, TheCall, 0))
+    return true;
+  return false;
+}
+
 static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() == 3);
   Expr *Arg1 = TheCall->getArg(1);
@@ -3371,6 +3386,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    // Ensure input expr type is a scalar/vector and the same as the return type
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+      return true;
+    if (CheckWavePrefix(&SemaRef, TheCall))
+      return true;
+    ExprResult Expr = TheCall->getArg(0);
+    QualType ArgTyExpr = Expr.get()->getType();
+    TheCall->setType(ArgTyExpr);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
new file mode 100644
index 0000000000000..f22aa69ba45d5
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr) {
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i32([[TY]] %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.sum.i32([[TY]] %[[#]])
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WavePrefixSum(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.sum.i32([[TY]]) #[[#attr:]]
+// CHEC...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2025

@llvm/pr-subscribers-backend-x86

Author: Kai (kcloudy0717)

Changes

Issue: #99172

  • Implement WavePrefixSum clang builtin
  • Link WavePrefixSum clang builtin with hlsl_intrinsics.h
  • Add sema checks for WavePrefixSum to CheckHLSLBuiltinFunctionCall in SemaChecking.cpp
  • Add codegen for WavePrefixSum to EmitHLSLBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl
  • Create the int_dx_WavePrefixSum intrinsic in IntrinsicsDirectX.td
  • Create the DXILOpMapping of int_dx_WavePrefixSum to 121 in DXIL.td
  • Create the WavePrefixSum.ll and WavePrefixSum_errors.ll tests in llvm/test/CodeGen/DirectX/
  • Create the int_spv_WavePrefixSum intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the WavePrefixSum lowering and map it to int_spv_WavePrefixSum in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll

I also added a new macro GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED in conjunction with the new function getUnsignedIntrinsicVariant to make selecting unsigned variants of the intrinsic easier. As a result, I was able to replace getWaveActiveSumIntrinsic, getWaveActiveMaxIntrinsic, and getWaveActiveMinIntrinsic using the new macro.


Patch is 39.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167946.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+21-67)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.cpp (+23)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+39)
  • (modified) clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h (+99)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+32-3)
  • (added) clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl (+46)
  • (added) clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl (+28)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+2)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+24)
  • (modified) llvm/lib/Target/DirectX/DXILShaderFlags.cpp (+2)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+14-10)
  • (added) llvm/test/CodeGen/DirectX/WavePrefixSum.ll (+143)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll (+41)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index d4a3e34a43c53..8a41388bd2244 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5072,6 +5072,12 @@ def HLSLWaveGetLaneCount : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "unsigned int()";
 }
 
+def HLSLWavePrefixSum : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_prefix_sum"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_clamp"];
   let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index b6928ce7d9c44..f79433e755f94 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -240,61 +240,6 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
   return RT.getFirstBitUHighIntrinsic();
 }
 
-// Return wave active sum that corresponds to the QT scalar type
-static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
-                                               CGHLSLRuntime &RT, QualType QT) {
-  switch (Arch) {
-  case llvm::Triple::spirv:
-    return Intrinsic::spv_wave_reduce_sum;
-  case llvm::Triple::dxil: {
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::dx_wave_reduce_usum;
-    return Intrinsic::dx_wave_reduce_sum;
-  }
-  default:
-    llvm_unreachable("Intrinsic WaveActiveSum"
-                     " not supported by target architecture");
-  }
-}
-
-// Return wave active max that corresponds to the QT scalar type
-static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
-                                               CGHLSLRuntime &RT, QualType QT) {
-  switch (Arch) {
-  case llvm::Triple::spirv:
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::spv_wave_reduce_umax;
-    return Intrinsic::spv_wave_reduce_max;
-  case llvm::Triple::dxil: {
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::dx_wave_reduce_umax;
-    return Intrinsic::dx_wave_reduce_max;
-  }
-  default:
-    llvm_unreachable("Intrinsic WaveActiveMax"
-                     " not supported by target architecture");
-  }
-}
-
-// Return wave active min that corresponds to the QT scalar type
-static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
-                                               CGHLSLRuntime &RT, QualType QT) {
-  switch (Arch) {
-  case llvm::Triple::spirv:
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::spv_wave_reduce_umin;
-    return Intrinsic::spv_wave_reduce_min;
-  case llvm::Triple::dxil: {
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::dx_wave_reduce_umin;
-    return Intrinsic::dx_wave_reduce_min;
-  }
-  default:
-    llvm_unreachable("Intrinsic WaveActiveMin"
-                     " not supported by target architecture");
-  }
-}
-
 // Returns the mangled name for a builtin function that the SPIR-V backend
 // will expand into a spec Constant.
 static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
@@ -794,33 +739,33 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         ArrayRef{OpExpr});
   }
   case Builtin::BI__builtin_hlsl_wave_active_sum: {
-    // Due to the use of variadic arguments, explicitly retreive argument
+    // Due to the use of variadic arguments, explicitly retrieve argument
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
-    Intrinsic::ID IID = getWaveActiveSumIntrinsic(
-        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
-        E->getArg(0)->getType());
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveSumIntrinsic(
+        QT->isUnsignedIntegerType());
 
     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                                &CGM.getModule(), IID, {OpExpr->getType()}),
                            ArrayRef{OpExpr}, "hlsl.wave.active.sum");
   }
   case Builtin::BI__builtin_hlsl_wave_active_max: {
-    // Due to the use of variadic arguments, explicitly retreive argument
+    // Due to the use of variadic arguments, explicitly retrieve argument
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
-    Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
-        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
-        E->getArg(0)->getType());
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMaxIntrinsic(
+        QT->isUnsignedIntegerType());
 
     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                                &CGM.getModule(), IID, {OpExpr->getType()}),
                            ArrayRef{OpExpr}, "hlsl.wave.active.max");
   }
   case Builtin::BI__builtin_hlsl_wave_active_min: {
-    // Due to the use of variadic arguments, explicitly retreive argument
+    // Due to the use of variadic arguments, explicitly retrieve argument
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
-    Intrinsic::ID IID = getWaveActiveMinIntrinsic(
-        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
-        E->getArg(0)->getType());
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMinIntrinsic(
+        QT->isUnsignedIntegerType());
 
     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                                &CGM.getModule(), IID, {OpExpr->getType()}),
@@ -864,6 +809,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
             {OpExpr->getType()}),
         ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
   }
+  case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
+    Value *OpExpr = EmitScalarExpr(E->getArg(0));
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWavePrefixSumIntrinsic(
+        QT->isUnsignedIntegerType());
+    return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
+                               &CGM.getModule(), IID, {OpExpr->getType()}),
+                           ArrayRef{OpExpr}, "hlsl.wave.prefix.sum");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
     auto *Arg0 = E->getArg(0);
     Value *Op0 = EmitScalarExpr(Arg0);
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 4bdba9b3da502..f9b1928ac7c45 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -276,6 +276,29 @@ llvm::Triple::ArchType CGHLSLRuntime::getArch() {
   return CGM.getTarget().getTriple().getArch();
 }
 
+llvm::Intrinsic::ID
+CGHLSLRuntime::getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID) {
+  switch (IID) {
+    // DXIL intrinsics
+  case Intrinsic::dx_wave_reduce_sum:
+    return Intrinsic::dx_wave_reduce_usum;
+  case Intrinsic::dx_wave_reduce_max:
+    return Intrinsic::dx_wave_reduce_umax;
+  case Intrinsic::dx_wave_reduce_min:
+    return Intrinsic::dx_wave_reduce_umin;
+  case Intrinsic::dx_wave_prefix_sum:
+    return Intrinsic::dx_wave_prefix_usum;
+
+    // SPIR-V intrinsics
+  case Intrinsic::spv_wave_reduce_max:
+    return Intrinsic::spv_wave_reduce_umax;
+  case Intrinsic::spv_wave_reduce_min:
+    return Intrinsic::spv_wave_reduce_umin;
+  default:
+    return IID;
+  }
+}
+
 // Emits constant global variables for buffer constants declarations
 // and creates metadata linking the constant globals with the buffer global.
 void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 488a322ca7569..671c7d434edbb 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -49,6 +49,33 @@
     }                                                                          \
   }
 
+// A function generator macro for picking the right intrinsic for the target
+// backend given IsUnsigned boolean condition. If IsUnsigned == true, it calls
+// getUnsignedIntrinsicVariant(IID) to retrieve the unsigned variant of the
+// intrinsic else the regular intrinsic is returned. (NOTE:
+// getUnsignedIntrinsicVariant returns IID itself if there is no unsigned
+// variant).
+#define GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(FunctionName,         \
+                                                         IntrinsicPostfix)     \
+  llvm::Intrinsic::ID get##FunctionName##Intrinsic(bool IsUnsigned) {          \
+    llvm::Triple::ArchType Arch = getArch();                                   \
+    switch (Arch) {                                                            \
+    case llvm::Triple::dxil: {                                                 \
+      static constexpr llvm::Intrinsic::ID IID =                               \
+          llvm::Intrinsic::dx_##IntrinsicPostfix;                              \
+      return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID;              \
+    }                                                                          \
+    case llvm::Triple::spirv: {                                                \
+      static constexpr llvm::Intrinsic::ID IID =                               \
+          llvm::Intrinsic::spv_##IntrinsicPostfix;                             \
+      return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID;              \
+    }                                                                          \
+    default:                                                                   \
+      llvm_unreachable("Intrinsic " #IntrinsicPostfix                          \
+                       " not supported by target architecture");               \
+    }                                                                          \
+  }
+
 using ResourceClass = llvm::dxil::ResourceClass;
 
 namespace llvm {
@@ -141,9 +168,17 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveSum,
+                                                   wave_reduce_sum)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMax,
+                                                   wave_reduce_max)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMin,
+                                                   wave_reduce_min)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WavePrefixSum,
+                                                   wave_prefix_sum)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitLow, firstbitlow)
@@ -246,6 +281,10 @@ class CGHLSLRuntime {
 
   llvm::Triple::ArchType getArch();
 
+  // Returns the unsigned variant of the given intrinsic ID if possible,
+  // otherwise, the original intrinsic ID is returned.
+  llvm::Intrinsic::ID getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID);
+
   llvm::DenseMap<const clang::RecordType *, llvm::TargetExtType *> LayoutTypes;
   unsigned SPIRVLastAssignedInputSemanticLocation = 0;
 };
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 2e2703de18cb1..d3b9af9695016 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2802,6 +2802,105 @@ __attribute__((convergent)) double3 WaveActiveSum(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
 __attribute__((convergent)) double4 WaveActiveSum(double4);
 
+//===----------------------------------------------------------------------===//
+// WavePrefixSum builtins
+//===----------------------------------------------------------------------===//
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half WavePrefixSum(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half2 WavePrefixSum(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half3 WavePrefixSum(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half4 WavePrefixSum(half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t WavePrefixSum(int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t2 WavePrefixSum(int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t3 WavePrefixSum(int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t4 WavePrefixSum(int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t WavePrefixSum(uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t2 WavePrefixSum(uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t3 WavePrefixSum(uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t4 WavePrefixSum(uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int WavePrefixSum(int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int2 WavePrefixSum(int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int3 WavePrefixSum(int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int4 WavePrefixSum(int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint WavePrefixSum(uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint2 WavePrefixSum(uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint3 WavePrefixSum(uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint4 WavePrefixSum(uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t WavePrefixSum(int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t2 WavePrefixSum(int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t3 WavePrefixSum(int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t4 WavePrefixSum(int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t WavePrefixSum(uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t2 WavePrefixSum(uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t3 WavePrefixSum(uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t4 WavePrefixSum(uint64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float WavePrefixSum(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float2 WavePrefixSum(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float3 WavePrefixSum(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float4 WavePrefixSum(float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double WavePrefixSum(double);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double2 WavePrefixSum(double2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double3 WavePrefixSum(double3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double4 WavePrefixSum(double4);
+
 //===----------------------------------------------------------------------===//
 // sign builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index e95fe16e6cb6c..d8abdb9b75d18 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2892,10 +2892,13 @@ static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
   return false;
 }
 
-static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
+// Check that the argument is not a bool or vector<bool>
+// Returns true on error
+static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall,
+                                       unsigned ArgIndex) {
   QualType BoolType = S->getASTContext().BoolTy;
-  assert(TheCall->getNumArgs() >= 1);
-  QualType ArgType = TheCall->getArg(0)->getType();
+  assert(ArgIndex < TheCall->getNumArgs());
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
   auto *VTy = ArgType->getAs<VectorType>();
   // is the bool or vector<bool>
   if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
@@ -2909,6 +2912,18 @@ static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
+  if (CheckNotBoolScalarOrVector(S, TheCall, 0))
+    return true;
+  return false;
+}
+
+static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
+  if (CheckNotBoolScalarOrVector(S, TheCall, 0))
+    return true;
+  return false;
+}
+
 static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() == 3);
   Expr *Arg1 = TheCall->getArg(1);
@@ -3371,6 +3386,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    // Ensure input expr type is a scalar/vector and the same as the return type
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+      return true;
+    if (CheckWavePrefix(&SemaRef, TheCall))
+      return true;
+    ExprResult Expr = TheCall->getArg(0);
+    QualType ArgTyExpr = Expr.get()->getType();
+    TheCall->setType(ArgTyExpr);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
new file mode 100644
index 0000000000000..f22aa69ba45d5
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr) {
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i32([[TY]] %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.sum.i32([[TY]] %[[#]])
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WavePrefixSum(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.sum.i32([[TY]]) #[[#attr:]]
+// CHEC...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2025

@llvm/pr-subscribers-backend-directx

Author: Kai (kcloudy0717)

Changes

Issue: #99172

  • Implement WavePrefixSum clang builtin
  • Link WavePrefixSum clang builtin with hlsl_intrinsics.h
  • Add sema checks for WavePrefixSum to CheckHLSLBuiltinFunctionCall in SemaChecking.cpp
  • Add codegen for WavePrefixSum to EmitHLSLBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl
  • Create the int_dx_WavePrefixSum intrinsic in IntrinsicsDirectX.td
  • Create the DXILOpMapping of int_dx_WavePrefixSum to 121 in DXIL.td
  • Create the WavePrefixSum.ll and WavePrefixSum_errors.ll tests in llvm/test/CodeGen/DirectX/
  • Create the int_spv_WavePrefixSum intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the WavePrefixSum lowering and map it to int_spv_WavePrefixSum in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll

I also added a new macro GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED in conjunction with the new function getUnsignedIntrinsicVariant to make selecting unsigned variants of the intrinsic easier. As a result, I was able to replace getWaveActiveSumIntrinsic, getWaveActiveMaxIntrinsic, and getWaveActiveMinIntrinsic using the new macro.


Patch is 39.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167946.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+21-67)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.cpp (+23)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+39)
  • (modified) clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h (+99)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+32-3)
  • (added) clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl (+46)
  • (added) clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl (+28)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+2)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+24)
  • (modified) llvm/lib/Target/DirectX/DXILShaderFlags.cpp (+2)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+14-10)
  • (added) llvm/test/CodeGen/DirectX/WavePrefixSum.ll (+143)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll (+41)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index d4a3e34a43c53..8a41388bd2244 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5072,6 +5072,12 @@ def HLSLWaveGetLaneCount : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "unsigned int()";
 }
 
+def HLSLWavePrefixSum : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_prefix_sum"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_clamp"];
   let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index b6928ce7d9c44..f79433e755f94 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -240,61 +240,6 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
   return RT.getFirstBitUHighIntrinsic();
 }
 
-// Return wave active sum that corresponds to the QT scalar type
-static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
-                                               CGHLSLRuntime &RT, QualType QT) {
-  switch (Arch) {
-  case llvm::Triple::spirv:
-    return Intrinsic::spv_wave_reduce_sum;
-  case llvm::Triple::dxil: {
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::dx_wave_reduce_usum;
-    return Intrinsic::dx_wave_reduce_sum;
-  }
-  default:
-    llvm_unreachable("Intrinsic WaveActiveSum"
-                     " not supported by target architecture");
-  }
-}
-
-// Return wave active max that corresponds to the QT scalar type
-static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
-                                               CGHLSLRuntime &RT, QualType QT) {
-  switch (Arch) {
-  case llvm::Triple::spirv:
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::spv_wave_reduce_umax;
-    return Intrinsic::spv_wave_reduce_max;
-  case llvm::Triple::dxil: {
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::dx_wave_reduce_umax;
-    return Intrinsic::dx_wave_reduce_max;
-  }
-  default:
-    llvm_unreachable("Intrinsic WaveActiveMax"
-                     " not supported by target architecture");
-  }
-}
-
-// Return wave active min that corresponds to the QT scalar type
-static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
-                                               CGHLSLRuntime &RT, QualType QT) {
-  switch (Arch) {
-  case llvm::Triple::spirv:
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::spv_wave_reduce_umin;
-    return Intrinsic::spv_wave_reduce_min;
-  case llvm::Triple::dxil: {
-    if (QT->isUnsignedIntegerType())
-      return Intrinsic::dx_wave_reduce_umin;
-    return Intrinsic::dx_wave_reduce_min;
-  }
-  default:
-    llvm_unreachable("Intrinsic WaveActiveMin"
-                     " not supported by target architecture");
-  }
-}
-
 // Returns the mangled name for a builtin function that the SPIR-V backend
 // will expand into a spec Constant.
 static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
@@ -794,33 +739,33 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         ArrayRef{OpExpr});
   }
   case Builtin::BI__builtin_hlsl_wave_active_sum: {
-    // Due to the use of variadic arguments, explicitly retreive argument
+    // Due to the use of variadic arguments, explicitly retrieve argument
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
-    Intrinsic::ID IID = getWaveActiveSumIntrinsic(
-        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
-        E->getArg(0)->getType());
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveSumIntrinsic(
+        QT->isUnsignedIntegerType());
 
     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                                &CGM.getModule(), IID, {OpExpr->getType()}),
                            ArrayRef{OpExpr}, "hlsl.wave.active.sum");
   }
   case Builtin::BI__builtin_hlsl_wave_active_max: {
-    // Due to the use of variadic arguments, explicitly retreive argument
+    // Due to the use of variadic arguments, explicitly retrieve argument
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
-    Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
-        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
-        E->getArg(0)->getType());
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMaxIntrinsic(
+        QT->isUnsignedIntegerType());
 
     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                                &CGM.getModule(), IID, {OpExpr->getType()}),
                            ArrayRef{OpExpr}, "hlsl.wave.active.max");
   }
   case Builtin::BI__builtin_hlsl_wave_active_min: {
-    // Due to the use of variadic arguments, explicitly retreive argument
+    // Due to the use of variadic arguments, explicitly retrieve argument
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
-    Intrinsic::ID IID = getWaveActiveMinIntrinsic(
-        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
-        E->getArg(0)->getType());
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMinIntrinsic(
+        QT->isUnsignedIntegerType());
 
     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                                &CGM.getModule(), IID, {OpExpr->getType()}),
@@ -864,6 +809,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
             {OpExpr->getType()}),
         ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
   }
+  case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
+    Value *OpExpr = EmitScalarExpr(E->getArg(0));
+    QualType QT = E->getArg(0)->getType();
+    Intrinsic::ID IID = CGM.getHLSLRuntime().getWavePrefixSumIntrinsic(
+        QT->isUnsignedIntegerType());
+    return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
+                               &CGM.getModule(), IID, {OpExpr->getType()}),
+                           ArrayRef{OpExpr}, "hlsl.wave.prefix.sum");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
     auto *Arg0 = E->getArg(0);
     Value *Op0 = EmitScalarExpr(Arg0);
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 4bdba9b3da502..f9b1928ac7c45 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -276,6 +276,29 @@ llvm::Triple::ArchType CGHLSLRuntime::getArch() {
   return CGM.getTarget().getTriple().getArch();
 }
 
+llvm::Intrinsic::ID
+CGHLSLRuntime::getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID) {
+  switch (IID) {
+    // DXIL intrinsics
+  case Intrinsic::dx_wave_reduce_sum:
+    return Intrinsic::dx_wave_reduce_usum;
+  case Intrinsic::dx_wave_reduce_max:
+    return Intrinsic::dx_wave_reduce_umax;
+  case Intrinsic::dx_wave_reduce_min:
+    return Intrinsic::dx_wave_reduce_umin;
+  case Intrinsic::dx_wave_prefix_sum:
+    return Intrinsic::dx_wave_prefix_usum;
+
+    // SPIR-V intrinsics
+  case Intrinsic::spv_wave_reduce_max:
+    return Intrinsic::spv_wave_reduce_umax;
+  case Intrinsic::spv_wave_reduce_min:
+    return Intrinsic::spv_wave_reduce_umin;
+  default:
+    return IID;
+  }
+}
+
 // Emits constant global variables for buffer constants declarations
 // and creates metadata linking the constant globals with the buffer global.
 void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 488a322ca7569..671c7d434edbb 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -49,6 +49,33 @@
     }                                                                          \
   }
 
+// A function generator macro for picking the right intrinsic for the target
+// backend given IsUnsigned boolean condition. If IsUnsigned == true, it calls
+// getUnsignedIntrinsicVariant(IID) to retrieve the unsigned variant of the
+// intrinsic else the regular intrinsic is returned. (NOTE:
+// getUnsignedIntrinsicVariant returns IID itself if there is no unsigned
+// variant).
+#define GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(FunctionName,         \
+                                                         IntrinsicPostfix)     \
+  llvm::Intrinsic::ID get##FunctionName##Intrinsic(bool IsUnsigned) {          \
+    llvm::Triple::ArchType Arch = getArch();                                   \
+    switch (Arch) {                                                            \
+    case llvm::Triple::dxil: {                                                 \
+      static constexpr llvm::Intrinsic::ID IID =                               \
+          llvm::Intrinsic::dx_##IntrinsicPostfix;                              \
+      return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID;              \
+    }                                                                          \
+    case llvm::Triple::spirv: {                                                \
+      static constexpr llvm::Intrinsic::ID IID =                               \
+          llvm::Intrinsic::spv_##IntrinsicPostfix;                             \
+      return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID;              \
+    }                                                                          \
+    default:                                                                   \
+      llvm_unreachable("Intrinsic " #IntrinsicPostfix                          \
+                       " not supported by target architecture");               \
+    }                                                                          \
+  }
+
 using ResourceClass = llvm::dxil::ResourceClass;
 
 namespace llvm {
@@ -141,9 +168,17 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveSum,
+                                                   wave_reduce_sum)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMax,
+                                                   wave_reduce_max)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMin,
+                                                   wave_reduce_min)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
+  GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WavePrefixSum,
+                                                   wave_prefix_sum)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitLow, firstbitlow)
@@ -246,6 +281,10 @@ class CGHLSLRuntime {
 
   llvm::Triple::ArchType getArch();
 
+  // Returns the unsigned variant of the given intrinsic ID if possible,
+  // otherwise, the original intrinsic ID is returned.
+  llvm::Intrinsic::ID getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID);
+
   llvm::DenseMap<const clang::RecordType *, llvm::TargetExtType *> LayoutTypes;
   unsigned SPIRVLastAssignedInputSemanticLocation = 0;
 };
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 2e2703de18cb1..d3b9af9695016 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2802,6 +2802,105 @@ __attribute__((convergent)) double3 WaveActiveSum(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
 __attribute__((convergent)) double4 WaveActiveSum(double4);
 
+//===----------------------------------------------------------------------===//
+// WavePrefixSum builtins
+//===----------------------------------------------------------------------===//
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half WavePrefixSum(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half2 WavePrefixSum(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half3 WavePrefixSum(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) half4 WavePrefixSum(half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t WavePrefixSum(int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t2 WavePrefixSum(int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t3 WavePrefixSum(int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int16_t4 WavePrefixSum(int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t WavePrefixSum(uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t2 WavePrefixSum(uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t3 WavePrefixSum(uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint16_t4 WavePrefixSum(uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int WavePrefixSum(int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int2 WavePrefixSum(int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int3 WavePrefixSum(int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int4 WavePrefixSum(int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint WavePrefixSum(uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint2 WavePrefixSum(uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint3 WavePrefixSum(uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint4 WavePrefixSum(uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t WavePrefixSum(int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t2 WavePrefixSum(int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t3 WavePrefixSum(int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) int64_t4 WavePrefixSum(int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t WavePrefixSum(uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t2 WavePrefixSum(uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t3 WavePrefixSum(uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) uint64_t4 WavePrefixSum(uint64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float WavePrefixSum(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float2 WavePrefixSum(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float3 WavePrefixSum(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) float4 WavePrefixSum(float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double WavePrefixSum(double);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double2 WavePrefixSum(double2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double3 WavePrefixSum(double3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
+__attribute__((convergent)) double4 WavePrefixSum(double4);
+
 //===----------------------------------------------------------------------===//
 // sign builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index e95fe16e6cb6c..d8abdb9b75d18 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2892,10 +2892,13 @@ static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
   return false;
 }
 
-static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
+// Check that the argument is not a bool or vector<bool>
+// Returns true on error
+static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall,
+                                       unsigned ArgIndex) {
   QualType BoolType = S->getASTContext().BoolTy;
-  assert(TheCall->getNumArgs() >= 1);
-  QualType ArgType = TheCall->getArg(0)->getType();
+  assert(ArgIndex < TheCall->getNumArgs());
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
   auto *VTy = ArgType->getAs<VectorType>();
   // is the bool or vector<bool>
   if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
@@ -2909,6 +2912,18 @@ static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
+  if (CheckNotBoolScalarOrVector(S, TheCall, 0))
+    return true;
+  return false;
+}
+
+static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
+  if (CheckNotBoolScalarOrVector(S, TheCall, 0))
+    return true;
+  return false;
+}
+
 static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() == 3);
   Expr *Arg1 = TheCall->getArg(1);
@@ -3371,6 +3386,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    // Ensure input expr type is a scalar/vector and the same as the return type
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+      return true;
+    if (CheckWavePrefix(&SemaRef, TheCall))
+      return true;
+    ExprResult Expr = TheCall->getArg(0);
+    QualType ArgTyExpr = Expr.get()->getType();
+    TheCall->setType(ArgTyExpr);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
new file mode 100644
index 0000000000000..f22aa69ba45d5
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr) {
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i32([[TY]] %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.sum.i32([[TY]] %[[#]])
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WavePrefixSum(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.sum.i32([[TY]]) #[[#attr:]]
+// CHEC...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

backend:DirectX backend:SPIR-V backend:X86 clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants