@@ -82,6 +82,7 @@ enum class EntryPropsTag {
8282 ASStateTag,
8383 WaveSize,
8484 EntryRootSig,
85+ WaveRange = 23 ,
8586};
8687
8788} // namespace
@@ -177,30 +178,32 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
177178 case EntryPropsTag::ASStateTag:
178179 case EntryPropsTag::WaveSize:
179180 case EntryPropsTag::EntryRootSig:
181+ case EntryPropsTag::WaveRange:
180182 llvm_unreachable (" NYI: Unhandled entry property tag" );
181183 }
182184 return MDVals;
183185}
184186
185- static MDTuple *
186- getEntryPropAsMetadata ( const EntryProperties &EP, uint64_t EntryShaderFlags,
187- const Triple::EnvironmentType ShaderProfile ) {
187+ static MDTuple *getEntryPropAsMetadata (Module &M, const EntryProperties &EP,
188+ uint64_t EntryShaderFlags,
189+ const ModuleMetadataInfo &MMDI ) {
188190 SmallVector<Metadata *> MDVals;
189191 LLVMContext &Ctx = EP.Entry ->getContext ();
190192 if (EntryShaderFlags != 0 )
191193 MDVals.append (getTagValueAsMetadata (EntryPropsTag::ShaderFlags,
192- EntryShaderFlags , Ctx));
194+ MMDI. ShaderProfile , Ctx));
193195
194196 if (EP.Entry != nullptr ) {
195197 // FIXME: support more props.
196198 // See https://github.com/llvm/llvm-project/issues/57948.
197199 // Add shader kind for lib entries.
198- if (ShaderProfile == Triple::EnvironmentType::Library &&
200+ if (MMDI. ShaderProfile == Triple::EnvironmentType::Library &&
199201 EP.ShaderStage != Triple::EnvironmentType::Library)
200202 MDVals.append (getTagValueAsMetadata (EntryPropsTag::ShaderKind,
201203 getShaderStage (EP.ShaderStage ), Ctx));
202204
203205 if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
206+ // Handle mandatory "hlsl.numthreads"
204207 MDVals.emplace_back (ConstantAsMetadata::get (ConstantInt::get (
205208 Type::getInt32Ty (Ctx), static_cast <int >(EntryPropsTag::NumThreads))));
206209 Metadata *NumThreadVals[] = {ConstantAsMetadata::get (ConstantInt::get (
@@ -210,8 +213,47 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
210213 ConstantAsMetadata::get (ConstantInt::get (
211214 Type::getInt32Ty (Ctx), EP.NumThreadsZ ))};
212215 MDVals.emplace_back (MDNode::get (Ctx, NumThreadVals));
216+
217+ // Handle optional "hlsl.wavesize". The fields are optionally represented
218+ // if they are non-zero.
219+ if (EP.WaveSizeMin != 0 ) {
220+ bool IsWaveRange = VersionTuple (6 , 8 ) <= MMDI.ShaderModelVersion ;
221+ bool IsWaveSize =
222+ !IsWaveRange && VersionTuple (6 , 6 ) <= MMDI.ShaderModelVersion ;
223+
224+ if (!IsWaveRange && !IsWaveSize) {
225+ reportError (M, " Shader model 6.6 or greater is required to specify "
226+ " the \" hlsl.wavesize\" function attribute" );
227+ return nullptr ;
228+ }
229+
230+ if (EP.WaveSizeMax && !IsWaveRange) {
231+ reportError (
232+ M, " Shader model 6.8 or greater is required to specify "
233+ " wave size range values of the \" hlsl.wavesize\" function "
234+ " attribute" );
235+ return nullptr ;
236+ }
237+
238+ EntryPropsTag Tag =
239+ IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
240+ MDVals.emplace_back (ConstantAsMetadata::get (
241+ ConstantInt::get (Type::getInt32Ty (Ctx), static_cast <int >(Tag))));
242+
243+ SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get (
244+ ConstantInt::get (Type::getInt32Ty (Ctx), EP.WaveSizeMin ))};
245+ if (IsWaveRange) {
246+ WaveSizeVals.push_back (ConstantAsMetadata::get (
247+ ConstantInt::get (Type::getInt32Ty (Ctx), EP.WaveSizeMax )));
248+ WaveSizeVals.push_back (ConstantAsMetadata::get (
249+ ConstantInt::get (Type::getInt32Ty (Ctx), EP.WaveSizePref )));
250+ }
251+
252+ MDVals.emplace_back (MDNode::get (Ctx, WaveSizeVals));
253+ }
213254 }
214255 }
256+
215257 if (MDVals.empty ())
216258 return nullptr ;
217259 return MDNode::get (Ctx, MDVals);
@@ -236,12 +278,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
236278 return MDNode::get (Ctx, MDVals);
237279}
238280
239- static MDTuple *emitEntryMD (const EntryProperties &EP, MDTuple *Signatures ,
240- MDNode *MDResources,
281+ static MDTuple *emitEntryMD (Module &M, const EntryProperties &EP,
282+ MDTuple *Signatures, MDNode *MDResources,
241283 const uint64_t EntryShaderFlags,
242- const Triple::EnvironmentType ShaderProfile) {
243- MDTuple *Properties =
244- getEntryPropAsMetadata (EP, EntryShaderFlags, ShaderProfile);
284+ const ModuleMetadataInfo &MMDI) {
285+ MDTuple *Properties = getEntryPropAsMetadata (M, EP, EntryShaderFlags, MMDI);
245286 return constructEntryMetadata (EP.Entry , Signatures, MDResources, Properties,
246287 EP.Entry ->getContext ());
247288}
@@ -523,10 +564,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
523564 Twine (Triple::getEnvironmentTypeName (MMDI.ShaderProfile ) +
524565 " '" ));
525566 }
526-
527- EntryFnMDNodes.emplace_back (emitEntryMD (EntryProp, Signatures, ResourceMD,
528- EntryShaderFlags,
529- MMDI.ShaderProfile ));
567+ EntryFnMDNodes.emplace_back (emitEntryMD (
568+ M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
530569 }
531570
532571 NamedMDNode *EntryPointsNamedMD =
0 commit comments