Skip to content

Commit ad49111

Browse files
authored
[HLSL][DirectX] Add support for rootsig as a target environment (llvm#156373)
This pr implements support for a root signature as a target, as specified [here](https://github.com/llvm/wg-hlsl/blob/main/proposals/0029-root-signature-driver-options.md#target-root-signature-version). This is implemented in the following steps: 1. Add `rootsignature` as a shader model environment type and define `rootsig` as a `target_profile`. Only valid as versions 1.0 and 1.1 2. Updates `HLSLFrontendAction` to invoke a special handling of constructing the `ASTContext` if we are considering an `hlsl` file and with a `rootsignature` target 3. Defines the special handling to minimally instantiate the `Parser` and `Sema` to insert the `RootSignatureDecl` 4. Updates `CGHLSLRuntime` to emit the constructed root signature decl as part of `dx.rootsignatures` with a `null` entry function 5. Updates `DXILRootSignature` to handle emitting a root signature without an entry function 6. Updates `ToolChains/HLSL` to invoke `only-section=RTS0` to strip any other generated information Resolves: llvm#150286. ##### Implementation Considerations Ideally we could invoke this as part of `clang-dxc` without the need of a source file. However, the initialization of the `Parser` and `Lexer` becomes quite complicated to handle this. Technically, we could avoid generating any of the extra information that is removed in step 6. However, it seems better to re-use the logic in `llvm-objcopy` without any need for additional custom logic in `DXILRootSignature`.
1 parent 81a4fcb commit ad49111

File tree

24 files changed

+276
-54
lines changed

24 files changed

+276
-54
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,6 +1195,7 @@ static llvm::Triple::EnvironmentType getEnvironmentType(llvm::StringRef Environm
11951195
.Case("callable", llvm::Triple::Callable)
11961196
.Case("mesh", llvm::Triple::Mesh)
11971197
.Case("amplification", llvm::Triple::Amplification)
1198+
.Case("rootsignature", llvm::Triple::RootSignature)
11981199
.Case("library", llvm::Triple::Library)
11991200
.Default(llvm::Triple::UnknownEnvironment);
12001201
}

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13152,6 +13152,8 @@ def err_hlsl_attribute_needs_intangible_type: Error<"attribute %0 can be used on
1315213152
def err_hlsl_incorrect_num_initializers: Error<
1315313153
"too %select{few|many}0 initializers in list for type %1 "
1315413154
"(expected %2 but found %3)">;
13155+
def err_hlsl_rootsignature_entry: Error<
13156+
"rootsignature specified as target environment but entry, %0, was not defined">;
1315513157

1315613158
def err_hlsl_operator_unsupported : Error<
1315713159
"the '%select{&|*|->}0' operator is unsupported in HLSL">;

clang/include/clang/Driver/Options.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9444,7 +9444,8 @@ def target_profile : DXCJoinedOrSeparate<"T">, MetaVarName<"<profile>">,
94449444
"cs_6_0, cs_6_1, cs_6_2, cs_6_3, cs_6_4, cs_6_5, cs_6_6, cs_6_7,"
94459445
"lib_6_3, lib_6_4, lib_6_5, lib_6_6, lib_6_7, lib_6_x,"
94469446
"ms_6_5, ms_6_6, ms_6_7,"
9447-
"as_6_5, as_6_6, as_6_7">;
9447+
"as_6_5, as_6_6, as_6_7,"
9448+
"rootsig_1_0, rootsig_1_1">;
94489449
def emit_pristine_llvm : DXCFlag<"emit-pristine-llvm">,
94499450
HelpText<"Emit pristine LLVM IR from the frontend by not running any LLVM passes at all."
94509451
"Same as -S + -emit-llvm + -disable-llvm-passes.">;

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ IdentifierInfo *ParseHLSLRootSignature(Sema &Actions,
240240
llvm::dxbc::RootSignatureVersion Version,
241241
StringLiteral *Signature);
242242

243+
void HandleRootSignatureTarget(Sema &S, StringRef EntryRootSig);
244+
243245
} // namespace hlsl
244246
} // namespace clang
245247

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ class SemaHLSL : public SemaBase {
159159
RootSigOverrideIdent = DeclIdent;
160160
}
161161

162+
HLSLRootSignatureDecl *lookupRootSignatureOverrideDecl(DeclContext *DC) const;
163+
162164
// Returns true if any RootSignatureElement is invalid and a diagnostic was
163165
// produced
164166
bool

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,18 @@ void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
7070
DXILValMD->addOperand(Val);
7171
}
7272

73-
void addRootSignature(llvm::dxbc::RootSignatureVersion RootSigVer,
74-
ArrayRef<llvm::hlsl::rootsig::RootElement> Elements,
75-
llvm::Function *Fn, llvm::Module &M) {
73+
void addRootSignatureMD(llvm::dxbc::RootSignatureVersion RootSigVer,
74+
ArrayRef<llvm::hlsl::rootsig::RootElement> Elements,
75+
llvm::Function *Fn, llvm::Module &M) {
7676
auto &Ctx = M.getContext();
7777

7878
llvm::hlsl::rootsig::MetadataBuilder RSBuilder(Ctx, Elements);
7979
MDNode *RootSignature = RSBuilder.BuildRootSignature();
8080

8181
ConstantAsMetadata *Version = ConstantAsMetadata::get(ConstantInt::get(
8282
llvm::Type::getInt32Ty(Ctx), llvm::to_underlying(RootSigVer)));
83-
MDNode *MDVals =
84-
MDNode::get(Ctx, {ValueAsMetadata::get(Fn), RootSignature, Version});
83+
ValueAsMetadata *EntryFunc = Fn ? ValueAsMetadata::get(Fn) : nullptr;
84+
MDNode *MDVals = MDNode::get(Ctx, {EntryFunc, RootSignature, Version});
8585

8686
StringRef RootSignatureValKey = "dx.rootsignatures";
8787
auto *RootSignatureValMD = M.getOrInsertNamedMetadata(RootSignatureValKey);
@@ -449,6 +449,19 @@ void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *BufDecl) {
449449
}
450450
}
451451

452+
void CGHLSLRuntime::addRootSignature(
453+
const HLSLRootSignatureDecl *SignatureDecl) {
454+
llvm::Module &M = CGM.getModule();
455+
Triple T(M.getTargetTriple());
456+
457+
// Generated later with the function decl if not targeting root signature
458+
if (T.getEnvironment() != Triple::EnvironmentType::RootSignature)
459+
return;
460+
461+
addRootSignatureMD(SignatureDecl->getVersion(),
462+
SignatureDecl->getRootElements(), nullptr, M);
463+
}
464+
452465
llvm::TargetExtType *
453466
CGHLSLRuntime::getHLSLBufferLayoutType(const RecordType *StructType) {
454467
const auto Entry = LayoutTypes.find(StructType);
@@ -685,8 +698,8 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
685698
for (const Attr *Attr : FD->getAttrs()) {
686699
if (const auto *RSAttr = dyn_cast<RootSignatureAttr>(Attr)) {
687700
auto *RSDecl = RSAttr->getSignatureDecl();
688-
addRootSignature(RSDecl->getVersion(), RSDecl->getRootElements(), EntryFn,
689-
M);
701+
addRootSignatureMD(RSDecl->getVersion(), RSDecl->getRootElements(),
702+
EntryFn, M);
690703
}
691704
}
692705
}

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class VarDecl;
6464
class ParmVarDecl;
6565
class InitListExpr;
6666
class HLSLBufferDecl;
67+
class HLSLRootSignatureDecl;
6768
class HLSLVkBindingAttr;
6869
class HLSLResourceBindingAttr;
6970
class Type;
@@ -171,6 +172,7 @@ class CGHLSLRuntime {
171172
void generateGlobalCtorDtorCalls();
172173

173174
void addBuffer(const HLSLBufferDecl *D);
175+
void addRootSignature(const HLSLRootSignatureDecl *D);
174176
void finishCodeGen();
175177

176178
void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn);

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7545,7 +7545,7 @@ void CodeGenModule::EmitTopLevelDecl(Decl *D) {
75457545
break;
75467546

75477547
case Decl::HLSLRootSignature:
7548-
// Will be handled by attached function
7548+
getHLSLRuntime().addRootSignature(cast<HLSLRootSignatureDecl>(D));
75497549
break;
75507550
case Decl::HLSLBuffer:
75517551
getHLSLRuntime().addBuffer(cast<HLSLBufferDecl>(D));

clang/lib/Driver/ToolChains/HLSL.cpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,15 @@ bool isLegalShaderModel(Triple &T) {
6262
VersionTuple MinVer(6, 5);
6363
return MinVer <= Version;
6464
} break;
65+
case Triple::EnvironmentType::RootSignature:
66+
VersionTuple MinVer(1, 0);
67+
VersionTuple MaxVer(1, 1);
68+
return MinVer <= Version && Version <= MaxVer;
6569
}
6670
return false;
6771
}
6872

69-
std::optional<std::string> tryParseProfile(StringRef Profile) {
73+
std::optional<llvm::Triple> tryParseTriple(StringRef Profile) {
7074
// [ps|vs|gs|hs|ds|cs|ms|as]_[major]_[minor]
7175
SmallVector<StringRef, 3> Parts;
7276
Profile.split(Parts, "_");
@@ -84,6 +88,7 @@ std::optional<std::string> tryParseProfile(StringRef Profile) {
8488
.Case("lib", Triple::EnvironmentType::Library)
8589
.Case("ms", Triple::EnvironmentType::Mesh)
8690
.Case("as", Triple::EnvironmentType::Amplification)
91+
.Case("rootsig", Triple::EnvironmentType::RootSignature)
8792
.Default(Triple::EnvironmentType::UnknownEnvironment);
8893
if (Kind == Triple::EnvironmentType::UnknownEnvironment)
8994
return std::nullopt;
@@ -147,8 +152,14 @@ std::optional<std::string> tryParseProfile(StringRef Profile) {
147152
T.setOSName(Triple::getOSTypeName(Triple::OSType::ShaderModel).str() +
148153
VersionTuple(Major, Minor).getAsString());
149154
T.setEnvironment(Kind);
150-
if (isLegalShaderModel(T))
151-
return T.getTriple();
155+
156+
return T;
157+
}
158+
159+
std::optional<std::string> tryParseProfile(StringRef Profile) {
160+
std::optional<llvm::Triple> MaybeT = tryParseTriple(Profile);
161+
if (MaybeT && isLegalShaderModel(*MaybeT))
162+
return MaybeT->getTriple();
152163
else
153164
return std::nullopt;
154165
}
@@ -258,6 +269,19 @@ bool checkExtensionArgsAreValid(ArrayRef<std::string> SpvExtensionArgs,
258269
}
259270
return AllValid;
260271
}
272+
273+
bool isRootSignatureTarget(StringRef Profile) {
274+
if (std::optional<llvm::Triple> T = tryParseTriple(Profile))
275+
return T->getEnvironment() == Triple::EnvironmentType::RootSignature;
276+
return false;
277+
}
278+
279+
bool isRootSignatureTarget(DerivedArgList &Args) {
280+
if (const Arg *A = Args.getLastArg(options::OPT_target_profile))
281+
return isRootSignatureTarget(A->getValue());
282+
return false;
283+
}
284+
261285
} // namespace
262286

263287
void tools::hlsl::Validator::ConstructJob(Compilation &C, const JobAction &JA,
@@ -317,6 +341,12 @@ void tools::hlsl::LLVMObjcopy::ConstructJob(Compilation &C, const JobAction &JA,
317341
CmdArgs.push_back(Frs);
318342
}
319343

344+
if (const Arg *A = Args.getLastArg(options::OPT_target_profile))
345+
if (isRootSignatureTarget(A->getValue())) {
346+
const char *Fos = Args.MakeArgString("--only-section=RTS0");
347+
CmdArgs.push_back(Fos);
348+
}
349+
320350
assert(CmdArgs.size() > 2 && "Unnecessary invocation of objcopy.");
321351

322352
C.addCommand(std::make_unique<Command>(JA, *this, ResponseFileSupport::None(),
@@ -493,7 +523,8 @@ bool HLSLToolChain::requiresBinaryTranslation(DerivedArgList &Args) const {
493523

494524
bool HLSLToolChain::requiresObjcopy(DerivedArgList &Args) const {
495525
return Args.hasArg(options::OPT_dxc_Fo) &&
496-
Args.hasArg(options::OPT_dxc_strip_rootsignature);
526+
(Args.hasArg(options::OPT_dxc_strip_rootsignature) ||
527+
isRootSignatureTarget(Args));
497528
}
498529

499530
bool HLSLToolChain::isLastJob(DerivedArgList &Args,

clang/lib/Frontend/FrontendActions.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,16 +1310,27 @@ void HLSLFrontendAction::ExecuteAction() {
13101310
/*CodeCompleteConsumer=*/nullptr);
13111311
Sema &S = CI.getSema();
13121312

1313+
auto &TargetInfo = CI.getASTContext().getTargetInfo();
1314+
bool IsRootSignatureTarget =
1315+
TargetInfo.getTriple().getEnvironment() == llvm::Triple::RootSignature;
1316+
StringRef HLSLEntry = TargetInfo.getTargetOpts().HLSLEntry;
1317+
13131318
// Register HLSL specific callbacks
13141319
auto LangOpts = CI.getLangOpts();
1320+
StringRef RootSigName =
1321+
IsRootSignatureTarget ? HLSLEntry : LangOpts.HLSLRootSigOverride;
1322+
13151323
auto MacroCallback = std::make_unique<InjectRootSignatureCallback>(
1316-
S, LangOpts.HLSLRootSigOverride, LangOpts.HLSLRootSigVer);
1324+
S, RootSigName, LangOpts.HLSLRootSigVer);
13171325

13181326
Preprocessor &PP = CI.getPreprocessor();
13191327
PP.addPPCallbacks(std::move(MacroCallback));
13201328

1321-
// Invoke as normal
1322-
WrapperFrontendAction::ExecuteAction();
1329+
// If we are targeting a root signature, invoke custom handling
1330+
if (IsRootSignatureTarget)
1331+
return hlsl::HandleRootSignatureTarget(S, HLSLEntry);
1332+
else // otherwise, invoke as normal
1333+
return WrapperFrontendAction::ExecuteAction();
13231334
}
13241335

13251336
HLSLFrontendAction::HLSLFrontendAction(

0 commit comments

Comments
 (0)