2222#include " llvm/ADT/StringExtras.h"
2323#include " llvm/ADT/StringRef.h"
2424#include " llvm/Support/Casting.h"
25+ #include " llvm/Support/DXILABI.h"
2526#include " llvm/Support/ErrorHandling.h"
2627#include " llvm/TargetParser/Triple.h"
2728#include < iterator>
@@ -153,6 +154,25 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
153154 HLSLNumThreadsAttr (getASTContext (), AL, X, Y, Z);
154155}
155156
157+ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr (Decl *D,
158+ const AttributeCommonInfo &AL,
159+ int Min, int Max, int Preferred,
160+ int SpelledArgsCount) {
161+ if (HLSLWaveSizeAttr *WS = D->getAttr <HLSLWaveSizeAttr>()) {
162+ if (WS->getMin () != Min || WS->getMax () != Max ||
163+ WS->getPreferred () != Preferred ||
164+ WS->getSpelledArgsCount () != SpelledArgsCount) {
165+ Diag (WS->getLocation (), diag::err_hlsl_attribute_param_mismatch) << AL;
166+ Diag (AL.getLoc (), diag::note_conflicting_attribute);
167+ }
168+ return nullptr ;
169+ }
170+ HLSLWaveSizeAttr *Result = ::new (getASTContext ())
171+ HLSLWaveSizeAttr (getASTContext (), AL, Min, Max, Preferred);
172+ Result->setSpelledArgsCount (SpelledArgsCount);
173+ return Result;
174+ }
175+
156176HLSLShaderAttr *
157177SemaHLSL::mergeShaderAttr (Decl *D, const AttributeCommonInfo &AL,
158178 llvm::Triple::EnvironmentType ShaderType) {
@@ -224,7 +244,8 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
224244 const auto *ShaderAttr = FD->getAttr <HLSLShaderAttr>();
225245 assert (ShaderAttr && " Entry point has no shader attribute" );
226246 llvm::Triple::EnvironmentType ST = ShaderAttr->getType ();
227-
247+ auto &TargetInfo = getASTContext ().getTargetInfo ();
248+ VersionTuple Ver = TargetInfo.getTriple ().getOSVersion ();
228249 switch (ST) {
229250 case llvm::Triple::Pixel:
230251 case llvm::Triple::Vertex:
@@ -244,6 +265,13 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
244265 llvm::Triple::Mesh});
245266 FD->setInvalidDecl ();
246267 }
268+ if (const auto *WS = FD->getAttr <HLSLWaveSizeAttr>()) {
269+ DiagnoseAttrStageMismatch (WS, ST,
270+ {llvm::Triple::Compute,
271+ llvm::Triple::Amplification,
272+ llvm::Triple::Mesh});
273+ FD->setInvalidDecl ();
274+ }
247275 break ;
248276
249277 case llvm::Triple::Compute:
@@ -254,6 +282,19 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
254282 << llvm::Triple::getEnvironmentTypeName (ST);
255283 FD->setInvalidDecl ();
256284 }
285+ if (const auto *WS = FD->getAttr <HLSLWaveSizeAttr>()) {
286+ if (Ver < VersionTuple (6 , 6 )) {
287+ Diag (WS->getLocation (), diag::err_hlsl_attribute_in_wrong_shader_model)
288+ << WS << " 6.6" ;
289+ FD->setInvalidDecl ();
290+ } else if (WS->getSpelledArgsCount () > 1 && Ver < VersionTuple (6 , 8 )) {
291+ Diag (
292+ WS->getLocation (),
293+ diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
294+ << WS << WS->getSpelledArgsCount () << " 6.8" ;
295+ FD->setInvalidDecl ();
296+ }
297+ }
257298 break ;
258299 default :
259300 llvm_unreachable (" Unhandled environment in triple" );
@@ -357,6 +398,74 @@ void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
357398 D->addAttr (NewAttr);
358399}
359400
401+ static bool isValidWaveSizeValue (unsigned Value) {
402+ return llvm::isPowerOf2_32 (Value) && Value >= 4 && Value <= 128 ;
403+ }
404+
405+ void SemaHLSL::handleWaveSizeAttr (Decl *D, const ParsedAttr &AL) {
406+ // validate that the wavesize argument is a power of 2 between 4 and 128
407+ // inclusive
408+ unsigned SpelledArgsCount = AL.getNumArgs ();
409+ if (SpelledArgsCount == 0 || SpelledArgsCount > 3 )
410+ return ;
411+
412+ uint32_t Min;
413+ if (!SemaRef.checkUInt32Argument (AL, AL.getArgAsExpr (0 ), Min))
414+ return ;
415+
416+ uint32_t Max = 0 ;
417+ if (SpelledArgsCount > 1 &&
418+ !SemaRef.checkUInt32Argument (AL, AL.getArgAsExpr (1 ), Max))
419+ return ;
420+
421+ uint32_t Preferred = 0 ;
422+ if (SpelledArgsCount > 2 &&
423+ !SemaRef.checkUInt32Argument (AL, AL.getArgAsExpr (2 ), Preferred))
424+ return ;
425+
426+ if (SpelledArgsCount > 2 ) {
427+ if (!isValidWaveSizeValue (Preferred)) {
428+ Diag (AL.getArgAsExpr (2 )->getExprLoc (),
429+ diag::err_attribute_power_of_two_in_range)
430+ << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
431+ << Preferred;
432+ return ;
433+ }
434+ // Preferred not in range.
435+ if (Preferred < Min || Preferred > Max) {
436+ Diag (AL.getArgAsExpr (2 )->getExprLoc (),
437+ diag::err_attribute_power_of_two_in_range)
438+ << AL << Min << Max << Preferred;
439+ return ;
440+ }
441+ } else if (SpelledArgsCount > 1 ) {
442+ if (!isValidWaveSizeValue (Max)) {
443+ Diag (AL.getArgAsExpr (1 )->getExprLoc (),
444+ diag::err_attribute_power_of_two_in_range)
445+ << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
446+ return ;
447+ }
448+ if (Max < Min) {
449+ Diag (AL.getLoc (), diag::err_attribute_argument_invalid) << AL << 1 ;
450+ return ;
451+ } else if (Max == Min) {
452+ Diag (AL.getLoc (), diag::warn_attr_min_eq_max) << AL;
453+ }
454+ } else {
455+ if (!isValidWaveSizeValue (Min)) {
456+ Diag (AL.getArgAsExpr (0 )->getExprLoc (),
457+ diag::err_attribute_power_of_two_in_range)
458+ << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
459+ return ;
460+ }
461+ }
462+
463+ HLSLWaveSizeAttr *NewAttr =
464+ mergeWaveSizeAttr (D, AL, Min, Max, Preferred, SpelledArgsCount);
465+ if (NewAttr)
466+ D->addAttr (NewAttr);
467+ }
468+
360469static bool isLegalTypeForHLSLSV_DispatchThreadID (QualType T) {
361470 if (!T->hasUnsignedIntegerRepresentation ())
362471 return false ;
0 commit comments