Skip to content

Commit f8d7c6e

Browse files
jiangzhaomingDawn LUCI CQ
authored andcommitted
Dawn: Support streaming ShaderModuleParseResult
This CL add StreamOut support for std::optional, and add a wrapper UnsafeUnserializedValue to replace CacheKey::UnsafeUnkeyedValue and hold unserializable value that will do nothing when streaming in and will set to default constructed value when streamed out to. These help making ShaderModuleParseResult serializable for shader module blob cache. Bug: 42240459, 402772740 Change-Id: Ic6199e46a1be34c4d03047af361dfab1ef807817 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/244094 Reviewed-by: Corentin Wallez <cwallez@chromium.org> Commit-Queue: Zhaoming Jiang <zhaoming.jiang@microsoft.com>
1 parent 2987f3c commit f8d7c6e

File tree

13 files changed

+208
-145
lines changed

13 files changed

+208
-145
lines changed

src/dawn/native/CacheKey.h

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,34 +40,8 @@ class CacheKey : public stream::ByteVectorSink {
4040
using stream::ByteVectorSink::ByteVectorSink;
4141

4242
enum class Type { ComputePipeline, RenderPipeline, Shader };
43-
44-
template <typename T>
45-
class UnsafeUnkeyedValue {
46-
public:
47-
UnsafeUnkeyedValue() = default;
48-
// NOLINTNEXTLINE(runtime/explicit) allow implicit construction to decrease verbosity
49-
UnsafeUnkeyedValue(T&& value) : mValue(std::forward<T>(value)) {}
50-
51-
const T& UnsafeGetValue() const { return mValue; }
52-
53-
// Friend definition of StreamIn which can be found by ADL to override
54-
// stream::StreamIn<T>.
55-
friend constexpr void StreamIn(stream::Sink*, const UnsafeUnkeyedValue&) {}
56-
57-
// Enabling DAWN_SERIALIZABLE classes with UnsafeUnkeyedValue member to use default equality
58-
// operator. Equality comparison always returns true for the same type UnsafeUnkeyedValues.
59-
bool operator==(const UnsafeUnkeyedValue<T>& other) const { return true; }
60-
61-
private:
62-
T mValue;
63-
};
6443
};
6544

66-
template <typename T>
67-
CacheKey::UnsafeUnkeyedValue<T> UnsafeUnkeyedValue(T&& value) {
68-
return CacheKey::UnsafeUnkeyedValue<T>(std::forward<T>(value));
69-
}
70-
7145
} // namespace dawn::native
7246

7347
#endif // SRC_DAWN_NATIVE_CACHEKEY_H_

src/dawn/native/Serializable.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,46 @@ class Serializable {
6363
return CreateBlob(std::move(sink));
6464
}
6565
};
66+
67+
// UnsafeUnserializedValue holds a value of type T that does nothing when StreamIn to a sink, calls
68+
// default constructor when StreamOut from a source, and always compares as equal between objects of
69+
// the same type UnsafeUnserializedValue<T>. This is used for members in DAWN_SERIALIZABLE or
70+
// DAWN_MAKE_CACHE_REQUEST to prevent a member to get streamed into cache key, or enable having
71+
// unserializabled fields that get computed by cache missed function and returned together with
72+
// cached fields.
73+
template <typename T>
74+
class UnsafeUnserializedValue {
75+
public:
76+
UnsafeUnserializedValue() = default;
77+
explicit UnsafeUnserializedValue(T&& value) : mValue(std::forward<T>(value)) {}
78+
UnsafeUnserializedValue(const UnsafeUnserializedValue<T>& other)
79+
: mValue(other.UnsafeGetValue()) {}
80+
UnsafeUnserializedValue<T>& operator=(UnsafeUnserializedValue<T>&& other) {
81+
mValue = std::move(other.UnsafeGetValue());
82+
return *this;
83+
}
84+
85+
constexpr const T& UnsafeGetValue() const { return mValue; }
86+
constexpr T& UnsafeGetValue() { return mValue; }
87+
88+
friend constexpr void StreamIn(stream::Sink*, const UnsafeUnserializedValue<T>&) {}
89+
friend MaybeError StreamOut(stream::Source*, UnsafeUnserializedValue<T>* out) {
90+
// Call default constructor to initialize the value.
91+
out->mValue = T();
92+
return {};
93+
}
94+
// Enabling DAWN_SERIALIZABLE classes with UnsafeUnserializedOptional member to use default
95+
// equality operator. Equality comparison always returns true for the same type.
96+
bool operator==(const UnsafeUnserializedValue<T>& other) const { return true; }
97+
98+
protected:
99+
T mValue;
100+
};
101+
102+
// Template deduction guide for UnsafeUnserializedValue to enable deducting type of
103+
// UnsafeUnserializedValue(T{}) to UnsafeUnserializedValue<T>.
104+
template <typename T>
105+
UnsafeUnserializedValue(T&& value) -> UnsafeUnserializedValue<std::decay_t<T>>;
66106
} // namespace dawn::native
67107

68108
// Helper macro to define a struct or class along with VisitAll methods to call

src/dawn/native/d3d/D3DCompilationRequest.h

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -60,38 +60,38 @@ enum class Compiler { FXC, DXC };
6060
using InterStageShaderVariablesMask = std::bitset<tint::hlsl::writer::kMaxInterStageLocations>;
6161
using SubstituteOverrideConfig = std::unordered_map<tint::OverrideId, double>;
6262

63-
#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
64-
X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \
65-
X(CacheKey::UnsafeUnkeyedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \
66-
X(std::string_view, entryPointName) \
67-
X(SingleShaderStage, stage) \
68-
X(uint32_t, shaderModel) \
69-
X(uint32_t, compileFlags) \
70-
X(Compiler, compiler) \
71-
X(uint64_t, compilerVersion) \
72-
X(std::wstring_view, dxcShaderProfile) \
73-
X(std::string_view, fxcShaderProfile) \
74-
X(uint32_t, firstIndexOffsetShaderRegister) \
75-
X(uint32_t, firstIndexOffsetRegisterSpace) \
76-
X(tint::hlsl::writer::Options, tintOptions) \
77-
X(SubstituteOverrideConfig, substituteOverrideConfig) \
78-
X(LimitsForCompilationRequest, limits) \
79-
X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
80-
X(uint32_t, maxSubgroupSize) \
81-
X(bool, disableSymbolRenaming) \
82-
X(bool, dumpShaders) \
63+
#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
64+
X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \
65+
X(UnsafeUnserializedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \
66+
X(std::string_view, entryPointName) \
67+
X(SingleShaderStage, stage) \
68+
X(uint32_t, shaderModel) \
69+
X(uint32_t, compileFlags) \
70+
X(Compiler, compiler) \
71+
X(uint64_t, compilerVersion) \
72+
X(std::wstring_view, dxcShaderProfile) \
73+
X(std::string_view, fxcShaderProfile) \
74+
X(uint32_t, firstIndexOffsetShaderRegister) \
75+
X(uint32_t, firstIndexOffsetRegisterSpace) \
76+
X(tint::hlsl::writer::Options, tintOptions) \
77+
X(SubstituteOverrideConfig, substituteOverrideConfig) \
78+
X(LimitsForCompilationRequest, limits) \
79+
X(UnsafeUnserializedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
80+
X(uint32_t, maxSubgroupSize) \
81+
X(bool, disableSymbolRenaming) \
82+
X(bool, dumpShaders) \
8383
X(bool, useTintIR)
8484

85-
#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \
86-
X(bool, hasShaderF16Feature) \
87-
X(uint32_t, compileFlags) \
88-
X(Compiler, compiler) \
89-
X(uint64_t, compilerVersion) \
90-
X(std::wstring_view, dxcShaderProfile) \
91-
X(std::string_view, fxcShaderProfile) \
92-
X(CacheKey::UnsafeUnkeyedValue<pD3DCompile>, d3dCompile) \
93-
X(CacheKey::UnsafeUnkeyedValue<IDxcLibrary*>, dxcLibrary) \
94-
X(CacheKey::UnsafeUnkeyedValue<IDxcCompiler3*>, dxcCompiler)
85+
#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \
86+
X(bool, hasShaderF16Feature) \
87+
X(uint32_t, compileFlags) \
88+
X(Compiler, compiler) \
89+
X(uint64_t, compilerVersion) \
90+
X(std::wstring_view, dxcShaderProfile) \
91+
X(std::string_view, fxcShaderProfile) \
92+
X(UnsafeUnserializedValue<pD3DCompile>, d3dCompile) \
93+
X(UnsafeUnserializedValue<IDxcLibrary*>, dxcLibrary) \
94+
X(UnsafeUnserializedValue<IDxcCompiler3*>, dxcCompiler)
9595

9696
DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){};
9797
#undef HLSL_COMPILATION_REQUEST_MEMBERS
@@ -104,7 +104,7 @@ DAWN_SERIALIZABLE(struct,
104104
#define D3D_COMPILATION_REQUEST_MEMBERS(X) \
105105
X(HlslCompilationRequest, hlsl) \
106106
X(D3DBytecodeCompilationRequest, bytecode) \
107-
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
107+
X(UnsafeUnserializedValue<dawn::platform::Platform*>, tracePlatform)
108108

109109
DAWN_MAKE_CACHE_REQUEST(D3DCompilationRequest, D3D_COMPILATION_REQUEST_MEMBERS);
110110
#undef D3D_COMPILATION_REQUEST_MEMBERS

src/dawn/native/d3d/ShaderUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const d3d::D3DBytecodeCompilati
228228
}
229229

230230
MaybeError TranslateToHLSL(d3d::HlslCompilationRequest r,
231-
CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*> tracePlatform,
231+
UnsafeUnserializedValue<dawn::platform::Platform*> tracePlatform,
232232
CompiledShader* compiledShader) {
233233
tint::ast::transform::Manager transformManager;
234234
tint::ast::transform::DataMap transformInputs;

src/dawn/native/d3d11/ShaderModuleD3D11.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
9292
const bool useTintIR = device->IsToggleEnabled(Toggle::UseTintIR);
9393

9494
d3d::D3DCompilationRequest req = {};
95-
req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
95+
req.tracePlatform = UnsafeUnserializedValue(device->GetPlatform());
9696
req.hlsl.shaderModel = 50;
9797
req.hlsl.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
9898
req.hlsl.dumpShaders = device->IsToggleEnabled(Toggle::DumpShaders);
@@ -104,7 +104,8 @@ ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
104104

105105
// D3D11 only supports FXC.
106106
req.bytecode.compiler = d3d::Compiler::FXC;
107-
req.bytecode.d3dCompile = std::move(pD3DCompile{device->GetFunctions()->d3dCompile});
107+
req.bytecode.d3dCompile =
108+
UnsafeUnserializedValue(pD3DCompile{device->GetFunctions()->d3dCompile});
108109
req.bytecode.compilerVersion = D3D_COMPILER_VERSION;
109110
DAWN_ASSERT(device->GetDeviceInfo().shaderModel == 50);
110111
switch (stage) {
@@ -192,7 +193,7 @@ ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
192193
}
193194

194195
req.hlsl.shaderModuleHash = GetHash();
195-
req.hlsl.inputProgram = UseTintProgram();
196+
req.hlsl.inputProgram = UnsafeUnserializedValue(UseTintProgram());
196197
req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
197198
req.hlsl.stage = stage;
198199

@@ -228,8 +229,8 @@ ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
228229

229230
req.hlsl.substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
230231
req.hlsl.limits = LimitsForCompilationRequest::Create(device->GetLimits().v1);
231-
req.hlsl.adapterSupportedLimits =
232-
LimitsForCompilationRequest::Create(device->GetAdapter()->GetLimits().v1);
232+
req.hlsl.adapterSupportedLimits = UnsafeUnserializedValue(
233+
LimitsForCompilationRequest::Create(device->GetAdapter()->GetLimits().v1));
233234
req.hlsl.maxSubgroupSize = device->GetAdapter()->GetPhysicalDevice()->GetSubgroupMaxSize();
234235

235236
req.hlsl.tintOptions.disable_robustness = !device->IsRobustnessEnabled();

src/dawn/native/d3d12/ShaderModuleD3D12.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
134134
const bool useTintIR = device->IsToggleEnabled(Toggle::UseTintIR);
135135

136136
d3d::D3DCompilationRequest req = {};
137-
req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
137+
req.tracePlatform = UnsafeUnserializedValue(device->GetPlatform());
138138
req.hlsl.shaderModel = ToBackend(device->GetPhysicalDevice())
139139
->GetAppliedShaderModelUnderToggles(device->GetTogglesState());
140140
req.hlsl.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
@@ -154,13 +154,14 @@ ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
154154
ToBackend(device->GetPhysicalDevice())->GetBackend()->GetDxcVersion();
155155

156156
req.bytecode.compiler = d3d::Compiler::DXC;
157-
req.bytecode.dxcLibrary = device->GetDxcLibrary().Get();
158-
req.bytecode.dxcCompiler = device->GetDxcCompiler().Get();
157+
req.bytecode.dxcLibrary = UnsafeUnserializedValue(device->GetDxcLibrary().Get());
158+
req.bytecode.dxcCompiler = UnsafeUnserializedValue(device->GetDxcCompiler().Get());
159159
req.bytecode.compilerVersion = dxcVersionInfo.DxcCompilerVersion;
160160
req.bytecode.dxcShaderProfile = device->GetDxcShaderProfiles()[stage];
161161
} else {
162162
req.bytecode.compiler = d3d::Compiler::FXC;
163-
req.bytecode.d3dCompile = std::move(pD3DCompile{device->GetFunctions()->d3dCompile});
163+
req.bytecode.d3dCompile =
164+
UnsafeUnserializedValue(pD3DCompile{device->GetFunctions()->d3dCompile});
164165
req.bytecode.compilerVersion = D3D_COMPILER_VERSION;
165166
switch (stage) {
166167
case SingleShaderStage::Vertex:
@@ -322,7 +323,7 @@ ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
322323
}
323324

324325
req.hlsl.shaderModuleHash = GetHash();
325-
req.hlsl.inputProgram = UseTintProgram();
326+
req.hlsl.inputProgram = UnsafeUnserializedValue(UseTintProgram());
326327
req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
327328
req.hlsl.stage = stage;
328329
if (!useTintIR) {
@@ -380,8 +381,8 @@ ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
380381
device->IsToggleEnabled(Toggle::EnableIntegerRangeAnalysisInRobustness);
381382

382383
req.hlsl.limits = LimitsForCompilationRequest::Create(device->GetLimits().v1);
383-
req.hlsl.adapterSupportedLimits =
384-
LimitsForCompilationRequest::Create(device->GetAdapter()->GetLimits().v1);
384+
req.hlsl.adapterSupportedLimits = UnsafeUnserializedValue(
385+
LimitsForCompilationRequest::Create(device->GetAdapter()->GetLimits().v1));
385386
req.hlsl.maxSubgroupSize = device->GetAdapter()->GetPhysicalDevice()->GetSubgroupMaxSize();
386387

387388
CacheResult<d3d::CompiledShader> compiledShader;

src/dawn/native/metal/ShaderModuleMTL.mm

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,19 @@
5555
using OptionalVertexPullingTransformConfig = std::optional<tint::VertexPullingConfig>;
5656
using SubstituteOverrideConfig = std::unordered_map<tint::OverrideId, double>;
5757

58-
#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
59-
X(SingleShaderStage, stage) \
60-
X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \
61-
X(CacheKey::UnsafeUnkeyedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \
62-
X(SubstituteOverrideConfig, substituteOverrideConfig) \
63-
X(LimitsForCompilationRequest, limits) \
64-
X(CacheKey::UnsafeUnkeyedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
65-
X(uint32_t, maxSubgroupSize) \
66-
X(std::string, entryPointName) \
67-
X(bool, usesSubgroupMatrix) \
68-
X(bool, disableSymbolRenaming) \
69-
X(tint::msl::writer::Options, tintOptions) \
70-
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, platform)
58+
#define MSL_COMPILATION_REQUEST_MEMBERS(X) \
59+
X(SingleShaderStage, stage) \
60+
X(ShaderModuleBase::ShaderModuleHash, shaderModuleHash) \
61+
X(UnsafeUnserializedValue<ShaderModuleBase::ScopedUseTintProgram>, inputProgram) \
62+
X(SubstituteOverrideConfig, substituteOverrideConfig) \
63+
X(LimitsForCompilationRequest, limits) \
64+
X(UnsafeUnserializedValue<LimitsForCompilationRequest>, adapterSupportedLimits) \
65+
X(uint32_t, maxSubgroupSize) \
66+
X(std::string, entryPointName) \
67+
X(bool, usesSubgroupMatrix) \
68+
X(bool, disableSymbolRenaming) \
69+
X(tint::msl::writer::Options, tintOptions) \
70+
X(UnsafeUnserializedValue<dawn::platform::Platform*>, platform)
7171

7272
DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS);
7373
#undef MSL_COMPILATION_REQUEST_MEMBERS
@@ -275,12 +275,12 @@
275275
MslCompilationRequest req = {};
276276
req.stage = stage;
277277
req.shaderModuleHash = programmableStage.module->GetHash();
278-
req.inputProgram = programmableStage.module->UseTintProgram();
278+
req.inputProgram = UnsafeUnserializedValue(programmableStage.module->UseTintProgram());
279279
req.substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage);
280280
req.entryPointName = programmableStage.entryPoint.c_str();
281281
req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
282282
req.usesSubgroupMatrix = programmableStage.metadata->usesSubgroupMatrix;
283-
req.platform = UnsafeUnkeyedValue(device->GetPlatform());
283+
req.platform = UnsafeUnserializedValue(device->GetPlatform());
284284

285285
req.tintOptions.strip_all_names = !req.disableSymbolRenaming;
286286
req.tintOptions.remapped_entry_point_name = device->GetIsolatedEntryPointName();
@@ -305,8 +305,8 @@
305305
device->IsToggleEnabled(Toggle::EnableIntegerRangeAnalysisInRobustness);
306306

307307
req.limits = LimitsForCompilationRequest::Create(device->GetLimits().v1);
308-
req.adapterSupportedLimits =
309-
LimitsForCompilationRequest::Create(device->GetAdapter()->GetLimits().v1);
308+
req.adapterSupportedLimits = UnsafeUnserializedValue(
309+
LimitsForCompilationRequest::Create(device->GetAdapter()->GetLimits().v1));
310310
req.maxSubgroupSize = device->GetAdapter()->GetPhysicalDevice()->GetSubgroupMaxSize();
311311

312312
CacheResult<MslCompilation> mslCompilation;

0 commit comments

Comments
 (0)