Skip to content

Commit 0ea5223

Browse files
authored
[DXC] Add -metal flag to DXC driver (#130173)
This adds a flag to the DXC driver to enable calling the metal shader converter if it is available to convert the final shader output for metal.
1 parent 23a44b9 commit 0ea5223

File tree

8 files changed

+81
-1
lines changed

8 files changed

+81
-1
lines changed

clang/include/clang/Driver/Action.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,10 @@ class Action {
7575
LinkerWrapperJobClass,
7676
StaticLibJobClass,
7777
BinaryAnalyzeJobClass,
78+
BinaryTranslatorJobClass,
7879

7980
JobClassFirst = PreprocessJobClass,
80-
JobClassLast = BinaryAnalyzeJobClass
81+
JobClassLast = BinaryTranslatorJobClass
8182
};
8283

8384
// The offloading kind determines if this action is binded to a particular
@@ -675,6 +676,17 @@ class BinaryAnalyzeJobAction : public JobAction {
675676
}
676677
};
677678

679+
class BinaryTranslatorJobAction : public JobAction {
680+
void anchor() override;
681+
682+
public:
683+
BinaryTranslatorJobAction(Action *Input, types::ID Type);
684+
685+
static bool classof(const Action *A) {
686+
return A->getKind() == BinaryTranslatorJobClass;
687+
}
688+
};
689+
678690
} // namespace driver
679691
} // namespace clang
680692

clang/include/clang/Driver/Options.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9085,6 +9085,7 @@ def : Option<["/", "-"], "Qembed_debug", KIND_FLAG>, Group<dxc_Group>,
90859085
HelpText<"Embed PDB in shader container (ignored)">;
90869086
def spirv : DXCFlag<"spirv">,
90879087
HelpText<"Generate SPIR-V code">;
9088+
def metal : DXCFlag<"metal">, HelpText<"Generate Metal library">;
90889089
def fspv_target_env_EQ : Joined<["-"], "fspv-target-env=">, Group<dxc_Group>,
90899090
HelpText<"Specify the target environment">,
90909091
Values<"vulkan1.2, vulkan1.3">;

clang/lib/Driver/Action.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ const char *Action::getClassName(ActionClass AC) {
5050
return "static-lib-linker";
5151
case BinaryAnalyzeJobClass:
5252
return "binary-analyzer";
53+
case BinaryTranslatorJobClass:
54+
return "binary-translator";
5355
}
5456

5557
llvm_unreachable("invalid class");
@@ -459,3 +461,9 @@ void BinaryAnalyzeJobAction::anchor() {}
459461

460462
BinaryAnalyzeJobAction::BinaryAnalyzeJobAction(Action *Input, types::ID Type)
461463
: JobAction(BinaryAnalyzeJobClass, Input, Type) {}
464+
465+
void BinaryTranslatorJobAction::anchor() {}
466+
467+
BinaryTranslatorJobAction::BinaryTranslatorJobAction(Action *Input,
468+
types::ID Type)
469+
: JobAction(BinaryTranslatorJobClass, Input, Type) {}

clang/lib/Driver/Driver.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4669,6 +4669,16 @@ void Driver::BuildActions(Compilation &C, DerivedArgList &Args,
46694669
Actions.push_back(C.MakeAction<BinaryAnalyzeJobAction>(
46704670
LastAction, types::TY_DX_CONTAINER));
46714671
}
4672+
if (Args.getLastArg(options::OPT_metal)) {
4673+
Action *LastAction = Actions.back();
4674+
// Metal shader converter runs on DXIL containers, which can either be
4675+
// validated (in which case they are TY_DX_CONTAINER), or unvalidated
4676+
// (TY_OBJECT).
4677+
if (LastAction->getType() == types::TY_DX_CONTAINER ||
4678+
LastAction->getType() == types::TY_Object)
4679+
Actions.push_back(C.MakeAction<BinaryTranslatorJobAction>(
4680+
LastAction, types::TY_DX_CONTAINER));
4681+
}
46724682
}
46734683

46744684
// Claim ignored clang-cl options.

clang/lib/Driver/ToolChain.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ Tool *ToolChain::getTool(Action::ActionClass AC) const {
639639
case Action::DsymutilJobClass:
640640
case Action::VerifyDebugInfoJobClass:
641641
case Action::BinaryAnalyzeJobClass:
642+
case Action::BinaryTranslatorJobClass:
642643
llvm_unreachable("Invalid tool kind.");
643644

644645
case Action::CompileJobClass:

clang/lib/Driver/ToolChains/HLSL.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,22 @@ void tools::hlsl::Validator::ConstructJob(Compilation &C, const JobAction &JA,
198198
Exec, CmdArgs, Inputs, Input));
199199
}
200200

201+
void tools::hlsl::MetalConverter::ConstructJob(
202+
Compilation &C, const JobAction &JA, const InputInfo &Output,
203+
const InputInfoList &Inputs, const ArgList &Args,
204+
const char *LinkingOutput) const {
205+
std::string MSCPath = getToolChain().GetProgramPath("metal-shaderconverter");
206+
ArgStringList CmdArgs;
207+
const InputInfo &Input = Inputs[0];
208+
CmdArgs.push_back(Input.getFilename());
209+
CmdArgs.push_back("-o");
210+
CmdArgs.push_back(Input.getFilename());
211+
212+
const char *Exec = Args.MakeArgString(MSCPath);
213+
C.addCommand(std::make_unique<Command>(JA, *this, ResponseFileSupport::None(),
214+
Exec, CmdArgs, Inputs, Input));
215+
}
216+
201217
/// DirectX Toolchain
202218
HLSLToolChain::HLSLToolChain(const Driver &D, const llvm::Triple &Triple,
203219
const ArgList &Args)
@@ -214,6 +230,10 @@ Tool *clang::driver::toolchains::HLSLToolChain::getTool(
214230
if (!Validator)
215231
Validator.reset(new tools::hlsl::Validator(*this));
216232
return Validator.get();
233+
case Action::BinaryTranslatorJobClass:
234+
if (!MetalConverter)
235+
MetalConverter.reset(new tools::hlsl::MetalConverter(*this));
236+
return MetalConverter.get();
217237
default:
218238
return ToolChain::getTool(AC);
219239
}

clang/lib/Driver/ToolChains/HLSL.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ class LLVM_LIBRARY_VISIBILITY Validator : public Tool {
2929
const llvm::opt::ArgList &TCArgs,
3030
const char *LinkingOutput) const override;
3131
};
32+
33+
class LLVM_LIBRARY_VISIBILITY MetalConverter : public Tool {
34+
public:
35+
MetalConverter(const ToolChain &TC)
36+
: Tool("hlsl::MetalConverter", "metal-shaderconverter", TC) {}
37+
38+
bool hasIntegratedCPP() const override { return false; }
39+
40+
void ConstructJob(Compilation &C, const JobAction &JA,
41+
const InputInfo &Output, const InputInfoList &Inputs,
42+
const llvm::opt::ArgList &TCArgs,
43+
const char *LinkingOutput) const override;
44+
};
3245
} // namespace hlsl
3346
} // namespace tools
3447

@@ -57,6 +70,7 @@ class LLVM_LIBRARY_VISIBILITY HLSLToolChain : public ToolChain {
5770

5871
private:
5972
mutable std::unique_ptr<tools::hlsl::Validator> Validator;
73+
mutable std::unique_ptr<tools::hlsl::MetalConverter> MetalConverter;
6074
};
6175

6276
} // end namespace toolchains
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: %clang_dxc -T cs_6_0 %s -metal -Fo tmp.mtl -### 2>&1 | FileCheck %s
2+
// RUN: %clang_dxc -T cs_6_0 %s -metal -Vd -Fo tmp.mtl -### 2>&1 | FileCheck %s
3+
// CHECK: "{{.*}}metal-shaderconverter{{(.exe)?}}" "tmp.mtl" "-o" "tmp.mtl"
4+
5+
// RUN: %clang_dxc -T cs_6_0 %s -metal -### 2>&1 | FileCheck --check-prefix=NO_MTL %s
6+
// NO_MTL-NOT: metal-shaderconverter
7+
8+
RWBuffer<float4> In : register(u0, space0);
9+
RWBuffer<float4> Out : register(u1, space4);
10+
11+
[numthreads(1,1,1)]
12+
void main(uint GI : SV_GroupIndex) {
13+
Out[GI] = In[GI] * In[GI];
14+
}

0 commit comments

Comments
 (0)