@@ -240,6 +240,61 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
240240 return RT.getFirstBitUHighIntrinsic ();
241241}
242242
243+ // Return wave active sum that corresponds to the QT scalar type
244+ static Intrinsic::ID getWaveActiveSumIntrinsic (llvm::Triple::ArchType Arch,
245+ CGHLSLRuntime &RT, QualType QT) {
246+ switch (Arch) {
247+ case llvm::Triple::spirv:
248+ return Intrinsic::spv_wave_reduce_sum;
249+ case llvm::Triple::dxil: {
250+ if (QT->isUnsignedIntegerType ())
251+ return Intrinsic::dx_wave_reduce_usum;
252+ return Intrinsic::dx_wave_reduce_sum;
253+ }
254+ default :
255+ llvm_unreachable (" Intrinsic WaveActiveSum"
256+ " not supported by target architecture" );
257+ }
258+ }
259+
260+ // Return wave active max that corresponds to the QT scalar type
261+ static Intrinsic::ID getWaveActiveMaxIntrinsic (llvm::Triple::ArchType Arch,
262+ CGHLSLRuntime &RT, QualType QT) {
263+ switch (Arch) {
264+ case llvm::Triple::spirv:
265+ if (QT->isUnsignedIntegerType ())
266+ return Intrinsic::spv_wave_reduce_umax;
267+ return Intrinsic::spv_wave_reduce_max;
268+ case llvm::Triple::dxil: {
269+ if (QT->isUnsignedIntegerType ())
270+ return Intrinsic::dx_wave_reduce_umax;
271+ return Intrinsic::dx_wave_reduce_max;
272+ }
273+ default :
274+ llvm_unreachable (" Intrinsic WaveActiveMax"
275+ " not supported by target architecture" );
276+ }
277+ }
278+
279+ // Return wave active min that corresponds to the QT scalar type
280+ static Intrinsic::ID getWaveActiveMinIntrinsic (llvm::Triple::ArchType Arch,
281+ CGHLSLRuntime &RT, QualType QT) {
282+ switch (Arch) {
283+ case llvm::Triple::spirv:
284+ if (QT->isUnsignedIntegerType ())
285+ return Intrinsic::spv_wave_reduce_umin;
286+ return Intrinsic::spv_wave_reduce_min;
287+ case llvm::Triple::dxil: {
288+ if (QT->isUnsignedIntegerType ())
289+ return Intrinsic::dx_wave_reduce_umin;
290+ return Intrinsic::dx_wave_reduce_min;
291+ }
292+ default :
293+ llvm_unreachable (" Intrinsic WaveActiveMin"
294+ " not supported by target architecture" );
295+ }
296+ }
297+
243298// Returns the mangled name for a builtin function that the SPIR-V backend
244299// will expand into a spec Constant.
245300static std::string getSpecConstantFunctionName (clang::QualType SpecConstantType,
@@ -739,33 +794,33 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
739794 ArrayRef{OpExpr});
740795 }
741796 case Builtin::BI__builtin_hlsl_wave_active_sum: {
742- // Due to the use of variadic arguments, explicitly retrieve argument
797+ // Due to the use of variadic arguments, explicitly retreive argument
743798 Value *OpExpr = EmitScalarExpr (E->getArg (0 ));
744- QualType QT = E-> getArg ( 0 )-> getType ();
745- Intrinsic::ID IID = CGM. getHLSLRuntime ().getWaveActiveSumIntrinsic (
746- QT-> isUnsignedIntegerType ());
799+ Intrinsic::ID IID = getWaveActiveSumIntrinsic (
800+ getTarget (). getTriple ().getArch (), CGM. getHLSLRuntime (),
801+ E-> getArg ( 0 )-> getType ());
747802
748803 return EmitRuntimeCall (Intrinsic::getOrInsertDeclaration (
749804 &CGM.getModule (), IID, {OpExpr->getType ()}),
750805 ArrayRef{OpExpr}, " hlsl.wave.active.sum" );
751806 }
752807 case Builtin::BI__builtin_hlsl_wave_active_max: {
753- // Due to the use of variadic arguments, explicitly retrieve argument
808+ // Due to the use of variadic arguments, explicitly retreive argument
754809 Value *OpExpr = EmitScalarExpr (E->getArg (0 ));
755- QualType QT = E-> getArg ( 0 )-> getType ();
756- Intrinsic::ID IID = CGM. getHLSLRuntime ().getWaveActiveMaxIntrinsic (
757- QT-> isUnsignedIntegerType ());
810+ Intrinsic::ID IID = getWaveActiveMaxIntrinsic (
811+ getTarget (). getTriple ().getArch (), CGM. getHLSLRuntime (),
812+ E-> getArg ( 0 )-> getType ());
758813
759814 return EmitRuntimeCall (Intrinsic::getOrInsertDeclaration (
760815 &CGM.getModule (), IID, {OpExpr->getType ()}),
761816 ArrayRef{OpExpr}, " hlsl.wave.active.max" );
762817 }
763818 case Builtin::BI__builtin_hlsl_wave_active_min: {
764- // Due to the use of variadic arguments, explicitly retrieve argument
819+ // Due to the use of variadic arguments, explicitly retreive argument
765820 Value *OpExpr = EmitScalarExpr (E->getArg (0 ));
766- QualType QT = E-> getArg ( 0 )-> getType ();
767- Intrinsic::ID IID = CGM. getHLSLRuntime ().getWaveActiveMinIntrinsic (
768- QT-> isUnsignedIntegerType ());
821+ Intrinsic::ID IID = getWaveActiveMinIntrinsic (
822+ getTarget (). getTriple ().getArch (), CGM. getHLSLRuntime (),
823+ E-> getArg ( 0 )-> getType ());
769824
770825 return EmitRuntimeCall (Intrinsic::getOrInsertDeclaration (
771826 &CGM.getModule (), IID, {OpExpr->getType ()}),
@@ -811,9 +866,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
811866 }
812867 case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
813868 Value *OpExpr = EmitScalarExpr (E->getArg (0 ));
814- QualType QT = E->getArg (0 )->getType ();
815- Intrinsic::ID IID = CGM.getHLSLRuntime ().getWavePrefixSumIntrinsic (
816- QT->isUnsignedIntegerType ());
869+ Intrinsic::ID IID = CGM.getHLSLRuntime ().getWavePrefixSumIntrinsic ();
817870 return EmitRuntimeCall (Intrinsic::getOrInsertDeclaration (
818871 &CGM.getModule (), IID, {OpExpr->getType ()}),
819872 ArrayRef{OpExpr}, " hlsl.wave.prefix.sum" );
0 commit comments