Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -5072,6 +5072,12 @@ def HLSLWaveGetLaneCount : LangBuiltin<"HLSL_LANG"> {
let Prototype = "unsigned int()";
}

def HLSLWavePrefixSum : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_prefix_sum"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
Expand Down
88 changes: 21 additions & 67 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,61 +240,6 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
return RT.getFirstBitUHighIntrinsic();
}

// Return wave active sum that corresponds to the QT scalar type
static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
CGHLSLRuntime &RT, QualType QT) {
switch (Arch) {
case llvm::Triple::spirv:
return Intrinsic::spv_wave_reduce_sum;
case llvm::Triple::dxil: {
if (QT->isUnsignedIntegerType())
return Intrinsic::dx_wave_reduce_usum;
return Intrinsic::dx_wave_reduce_sum;
}
default:
llvm_unreachable("Intrinsic WaveActiveSum"
" not supported by target architecture");
}
}

// Return wave active max that corresponds to the QT scalar type
static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
CGHLSLRuntime &RT, QualType QT) {
switch (Arch) {
case llvm::Triple::spirv:
if (QT->isUnsignedIntegerType())
return Intrinsic::spv_wave_reduce_umax;
return Intrinsic::spv_wave_reduce_max;
case llvm::Triple::dxil: {
if (QT->isUnsignedIntegerType())
return Intrinsic::dx_wave_reduce_umax;
return Intrinsic::dx_wave_reduce_max;
}
default:
llvm_unreachable("Intrinsic WaveActiveMax"
" not supported by target architecture");
}
}

// Return wave active min that corresponds to the QT scalar type
static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
CGHLSLRuntime &RT, QualType QT) {
switch (Arch) {
case llvm::Triple::spirv:
if (QT->isUnsignedIntegerType())
return Intrinsic::spv_wave_reduce_umin;
return Intrinsic::spv_wave_reduce_min;
case llvm::Triple::dxil: {
if (QT->isUnsignedIntegerType())
return Intrinsic::dx_wave_reduce_umin;
return Intrinsic::dx_wave_reduce_min;
}
default:
llvm_unreachable("Intrinsic WaveActiveMin"
" not supported by target architecture");
}
}

// Returns the mangled name for a builtin function that the SPIR-V backend
// will expand into a spec Constant.
static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
Expand Down Expand Up @@ -794,33 +739,33 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
ArrayRef{OpExpr});
}
case Builtin::BI__builtin_hlsl_wave_active_sum: {
// Due to the use of variadic arguments, explicitly retreive argument
// Due to the use of variadic arguments, explicitly retrieve argument
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID IID = getWaveActiveSumIntrinsic(
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
E->getArg(0)->getType());
QualType QT = E->getArg(0)->getType();
Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveSumIntrinsic(
QT->isUnsignedIntegerType());

return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
&CGM.getModule(), IID, {OpExpr->getType()}),
ArrayRef{OpExpr}, "hlsl.wave.active.sum");
}
case Builtin::BI__builtin_hlsl_wave_active_max: {
// Due to the use of variadic arguments, explicitly retreive argument
// Due to the use of variadic arguments, explicitly retrieve argument
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
E->getArg(0)->getType());
QualType QT = E->getArg(0)->getType();
Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMaxIntrinsic(
QT->isUnsignedIntegerType());

return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
&CGM.getModule(), IID, {OpExpr->getType()}),
ArrayRef{OpExpr}, "hlsl.wave.active.max");
}
case Builtin::BI__builtin_hlsl_wave_active_min: {
// Due to the use of variadic arguments, explicitly retreive argument
// Due to the use of variadic arguments, explicitly retrieve argument
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID IID = getWaveActiveMinIntrinsic(
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
E->getArg(0)->getType());
QualType QT = E->getArg(0)->getType();
Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMinIntrinsic(
QT->isUnsignedIntegerType());

return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
&CGM.getModule(), IID, {OpExpr->getType()}),
Expand Down Expand Up @@ -864,6 +809,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
{OpExpr->getType()}),
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
}
case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
QualType QT = E->getArg(0)->getType();
Intrinsic::ID IID = CGM.getHLSLRuntime().getWavePrefixSumIntrinsic(
QT->isUnsignedIntegerType());
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
&CGM.getModule(), IID, {OpExpr->getType()}),
ArrayRef{OpExpr}, "hlsl.wave.prefix.sum");
}
case Builtin::BI__builtin_hlsl_elementwise_sign: {
auto *Arg0 = E->getArg(0);
Value *Op0 = EmitScalarExpr(Arg0);
Expand Down
23 changes: 23 additions & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,29 @@ llvm::Triple::ArchType CGHLSLRuntime::getArch() {
return CGM.getTarget().getTriple().getArch();
}

llvm::Intrinsic::ID
CGHLSLRuntime::getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID) {
switch (IID) {
// DXIL intrinsics
case Intrinsic::dx_wave_reduce_sum:
return Intrinsic::dx_wave_reduce_usum;
case Intrinsic::dx_wave_reduce_max:
return Intrinsic::dx_wave_reduce_umax;
case Intrinsic::dx_wave_reduce_min:
return Intrinsic::dx_wave_reduce_umin;
case Intrinsic::dx_wave_prefix_sum:
return Intrinsic::dx_wave_prefix_usum;

// SPIR-V intrinsics
case Intrinsic::spv_wave_reduce_max:
return Intrinsic::spv_wave_reduce_umax;
case Intrinsic::spv_wave_reduce_min:
return Intrinsic::spv_wave_reduce_umin;
default:
return IID;
}
}

// Emits constant global variables for buffer constants declarations
// and creates metadata linking the constant globals with the buffer global.
void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
Expand Down
39 changes: 39 additions & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,33 @@
} \
}

// A function generator macro for picking the right intrinsic for the target
// backend given IsUnsigned boolean condition. If IsUnsigned == true, it calls
// getUnsignedIntrinsicVariant(IID) to retrieve the unsigned variant of the
// intrinsic else the regular intrinsic is returned. (NOTE:
// getUnsignedIntrinsicVariant returns IID itself if there is no unsigned
// variant).
#define GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(FunctionName, \
IntrinsicPostfix) \
llvm::Intrinsic::ID get##FunctionName##Intrinsic(bool IsUnsigned) { \
llvm::Triple::ArchType Arch = getArch(); \
switch (Arch) { \
case llvm::Triple::dxil: { \
static constexpr llvm::Intrinsic::ID IID = \
llvm::Intrinsic::dx_##IntrinsicPostfix; \
return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID; \
} \
case llvm::Triple::spirv: { \
static constexpr llvm::Intrinsic::ID IID = \
llvm::Intrinsic::spv_##IntrinsicPostfix; \
return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID; \
} \
default: \
llvm_unreachable("Intrinsic " #IntrinsicPostfix \
" not supported by target architecture"); \
} \
}

using ResourceClass = llvm::dxil::ResourceClass;

namespace llvm {
Expand Down Expand Up @@ -141,9 +168,17 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveSum,
wave_reduce_sum)
GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMax,
wave_reduce_max)
GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMin,
wave_reduce_min)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WavePrefixSum,
wave_prefix_sum)
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitLow, firstbitlow)
Expand Down Expand Up @@ -246,6 +281,10 @@ class CGHLSLRuntime {

llvm::Triple::ArchType getArch();

// Returns the unsigned variant of the given intrinsic ID if possible,
// otherwise, the original intrinsic ID is returned.
llvm::Intrinsic::ID getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID);

llvm::DenseMap<const clang::RecordType *, llvm::TargetExtType *> LayoutTypes;
unsigned SPIRVLastAssignedInputSemanticLocation = 0;
};
Expand Down
99 changes: 99 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2802,6 +2802,105 @@ __attribute__((convergent)) double3 WaveActiveSum(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) double4 WaveActiveSum(double4);

//===----------------------------------------------------------------------===//
// WavePrefixSum builtins
//===----------------------------------------------------------------------===//

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) half WavePrefixSum(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) half2 WavePrefixSum(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) half3 WavePrefixSum(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) half4 WavePrefixSum(half4);

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int16_t WavePrefixSum(int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int16_t2 WavePrefixSum(int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int16_t3 WavePrefixSum(int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int16_t4 WavePrefixSum(int16_t4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint16_t WavePrefixSum(uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint16_t2 WavePrefixSum(uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint16_t3 WavePrefixSum(uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint16_t4 WavePrefixSum(uint16_t4);
#endif

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int WavePrefixSum(int);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int2 WavePrefixSum(int2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int3 WavePrefixSum(int3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int4 WavePrefixSum(int4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint WavePrefixSum(uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint2 WavePrefixSum(uint2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint3 WavePrefixSum(uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint4 WavePrefixSum(uint4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int64_t WavePrefixSum(int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int64_t2 WavePrefixSum(int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int64_t3 WavePrefixSum(int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) int64_t4 WavePrefixSum(int64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint64_t WavePrefixSum(uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint64_t2 WavePrefixSum(uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint64_t3 WavePrefixSum(uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) uint64_t4 WavePrefixSum(uint64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) float WavePrefixSum(float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) float2 WavePrefixSum(float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) float3 WavePrefixSum(float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) float4 WavePrefixSum(float4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) double WavePrefixSum(double);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) double2 WavePrefixSum(double2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) double3 WavePrefixSum(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_sum)
__attribute__((convergent)) double4 WavePrefixSum(double4);

//===----------------------------------------------------------------------===//
// sign builtins
//===----------------------------------------------------------------------===//
Expand Down
35 changes: 32 additions & 3 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2892,10 +2892,13 @@ static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
return false;
}

static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
// Check that the argument is not a bool or vector<bool>
// Returns true on error
static bool CheckNotBoolScalarOrVector(Sema *S, CallExpr *TheCall,
unsigned ArgIndex) {
QualType BoolType = S->getASTContext().BoolTy;
assert(TheCall->getNumArgs() >= 1);
QualType ArgType = TheCall->getArg(0)->getType();
assert(ArgIndex < TheCall->getNumArgs());
QualType ArgType = TheCall->getArg(ArgIndex)->getType();
auto *VTy = ArgType->getAs<VectorType>();
// is the bool or vector<bool>
if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
Expand All @@ -2909,6 +2912,18 @@ static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
return false;
}

static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
if (CheckNotBoolScalarOrVector(S, TheCall, 0))
return true;
return false;
}

static bool CheckWavePrefix(Sema *S, CallExpr *TheCall) {
if (CheckNotBoolScalarOrVector(S, TheCall, 0))
return true;
return false;
}

static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() == 3);
Expr *Arg1 = TheCall->getArg(1);
Expand Down Expand Up @@ -3371,6 +3386,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;

// Ensure input expr type is a scalar/vector and the same as the return type
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
return true;
if (CheckWavePrefix(&SemaRef, TheCall))
return true;
ExprResult Expr = TheCall->getArg(0);
QualType ArgTyExpr = Expr.get()->getType();
TheCall->setType(ArgTyExpr);
break;
}
case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
Expand Down
Loading