@@ -295,6 +295,23 @@ static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
295295 }
296296}
297297
298+ // Return wave active max that corresponds to the QT scalar type
299+ static Intrinsic::ID getWavePrefixSumIntrinsic (llvm::Triple::ArchType Arch,
300+ CGHLSLRuntime &RT, QualType QT) {
301+ switch (Arch) {
302+ case llvm::Triple::spirv:
303+ return Intrinsic::spv_wave_prefix_sum;
304+ case llvm::Triple::dxil: {
305+ if (QT->isUnsignedIntegerType ())
306+ return Intrinsic::dx_wave_prefix_usum;
307+ return Intrinsic::dx_wave_prefix_sum;
308+ }
309+ default :
310+ llvm_unreachable (" Intrinsic WavePrefixSum"
311+ " not supported by target architecture" );
312+ }
313+ }
314+
298315// Returns the mangled name for a builtin function that the SPIR-V backend
299316// will expand into a spec Constant.
300317static std::string getSpecConstantFunctionName (clang::QualType SpecConstantType,
@@ -866,7 +883,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
866883 }
867884 case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
868885 Value *OpExpr = EmitScalarExpr (E->getArg (0 ));
869- Intrinsic::ID IID = CGM.getHLSLRuntime ().getWavePrefixSumIntrinsic ();
886+ Intrinsic::ID IID = getWavePrefixSumIntrinsic (
887+ getTarget ().getTriple ().getArch (), CGM.getHLSLRuntime (),
888+ E->getArg (0 )->getType ());
870889 return EmitRuntimeCall (Intrinsic::getOrInsertDeclaration (
871890 &CGM.getModule (), IID, {OpExpr->getType ()}),
872891 ArrayRef{OpExpr}, " hlsl.wave.prefix.sum" );
0 commit comments