@@ -82,6 +82,7 @@ enum class EntryPropsTag {
8282 ASStateTag,
8383 WaveSize,
8484 EntryRootSig,
85+ WaveRange = 23 ,
8586};
8687
8788} // namespace
@@ -177,14 +178,15 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
177178 case EntryPropsTag::ASStateTag:
178179 case EntryPropsTag::WaveSize:
179180 case EntryPropsTag::EntryRootSig:
181+ case EntryPropsTag::WaveRange:
180182 llvm_unreachable (" NYI: Unhandled entry property tag" );
181183 }
182184 return MDVals;
183185}
184186
185- static MDTuple *
186- getEntryPropAsMetadata ( const EntryProperties &EP, uint64_t EntryShaderFlags,
187- const Triple::EnvironmentType ShaderProfile ) {
187+ static MDTuple *getEntryPropAsMetadata (Module &M, const EntryProperties &EP,
188+ uint64_t EntryShaderFlags,
189+ const ModuleMetadataInfo &MMDI ) {
188190 SmallVector<Metadata *> MDVals;
189191 LLVMContext &Ctx = EP.Entry ->getContext ();
190192 if (EntryShaderFlags != 0 )
@@ -195,12 +197,13 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
195197 // FIXME: support more props.
196198 // See https://github.com/llvm/llvm-project/issues/57948.
197199 // Add shader kind for lib entries.
198- if (ShaderProfile == Triple::EnvironmentType::Library &&
200+ if (MMDI. ShaderProfile == Triple::EnvironmentType::Library &&
199201 EP.ShaderStage != Triple::EnvironmentType::Library)
200202 MDVals.append (getTagValueAsMetadata (EntryPropsTag::ShaderKind,
201203 getShaderStage (EP.ShaderStage ), Ctx));
202204
203205 if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
206+ // Handle mandatory "hlsl.numthreads"
204207 MDVals.emplace_back (ConstantAsMetadata::get (ConstantInt::get (
205208 Type::getInt32Ty (Ctx), static_cast <int >(EntryPropsTag::NumThreads))));
206209 Metadata *NumThreadVals[] = {ConstantAsMetadata::get (ConstantInt::get (
@@ -210,8 +213,48 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
210213 ConstantAsMetadata::get (ConstantInt::get (
211214 Type::getInt32Ty (Ctx), EP.NumThreadsZ ))};
212215 MDVals.emplace_back (MDNode::get (Ctx, NumThreadVals));
216+
217+ // Handle optional "hlsl.wavesize". The fields are optionally represented
218+ // if they are non-zero.
219+ if (EP.WaveSizeMin != 0 ) {
220+ bool IsWaveRange = VersionTuple (6 , 8 ) <= MMDI.ShaderModelVersion ;
221+ bool IsWaveSize =
222+ !IsWaveRange && VersionTuple (6 , 6 ) <= MMDI.ShaderModelVersion ;
223+
224+ if (!IsWaveRange && !IsWaveSize) {
225+ reportError (M, " Shader model 6.6 or greater is required to specify "
226+ " the \" hlsl.wavesize\" function attribute" );
227+ return nullptr ;
228+ }
229+
230+ // A range is being specified if EP.WaveSizeMax != 0
231+ if (EP.WaveSizeMax && !IsWaveRange) {
232+ reportError (
233+ M, " Shader model 6.8 or greater is required to specify "
234+ " wave size range values of the \" hlsl.wavesize\" function "
235+ " attribute" );
236+ return nullptr ;
237+ }
238+
239+ EntryPropsTag Tag =
240+ IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
241+ MDVals.emplace_back (ConstantAsMetadata::get (
242+ ConstantInt::get (Type::getInt32Ty (Ctx), static_cast <int >(Tag))));
243+
244+ SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get (
245+ ConstantInt::get (Type::getInt32Ty (Ctx), EP.WaveSizeMin ))};
246+ if (IsWaveRange) {
247+ WaveSizeVals.push_back (ConstantAsMetadata::get (
248+ ConstantInt::get (Type::getInt32Ty (Ctx), EP.WaveSizeMax )));
249+ WaveSizeVals.push_back (ConstantAsMetadata::get (
250+ ConstantInt::get (Type::getInt32Ty (Ctx), EP.WaveSizePref )));
251+ }
252+
253+ MDVals.emplace_back (MDNode::get (Ctx, WaveSizeVals));
254+ }
213255 }
214256 }
257+
215258 if (MDVals.empty ())
216259 return nullptr ;
217260 return MDNode::get (Ctx, MDVals);
@@ -236,12 +279,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
236279 return MDNode::get (Ctx, MDVals);
237280}
238281
239- static MDTuple *emitEntryMD (const EntryProperties &EP, MDTuple *Signatures ,
240- MDNode *MDResources,
282+ static MDTuple *emitEntryMD (Module &M, const EntryProperties &EP,
283+ MDTuple *Signatures, MDNode *MDResources,
241284 const uint64_t EntryShaderFlags,
242- const Triple::EnvironmentType ShaderProfile) {
243- MDTuple *Properties =
244- getEntryPropAsMetadata (EP, EntryShaderFlags, ShaderProfile);
285+ const ModuleMetadataInfo &MMDI) {
286+ MDTuple *Properties = getEntryPropAsMetadata (M, EP, EntryShaderFlags, MMDI);
245287 return constructEntryMetadata (EP.Entry , Signatures, MDResources, Properties,
246288 EP.Entry ->getContext ());
247289}
@@ -523,10 +565,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
523565 Twine (Triple::getEnvironmentTypeName (MMDI.ShaderProfile ) +
524566 " '" ));
525567 }
526-
527- EntryFnMDNodes.emplace_back (emitEntryMD (EntryProp, Signatures, ResourceMD,
528- EntryShaderFlags,
529- MMDI.ShaderProfile ));
568+ EntryFnMDNodes.emplace_back (emitEntryMD (
569+ M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
530570 }
531571
532572 NamedMDNode *EntryPointsNamedMD =
0 commit comments