Skip to content

Commit 65e8b09

Browse files
committed
Renderer: Move reflect_spirv to RenderingShaderContainer
This change introduces a new protected type, `ReflectedShaderStage` to `RenderingShaderContainer` that derived types use to access SPIR-V and the reflected module, `SpvReflectShaderModule` allowing implementations to use the reflection information to compile their platform-specific module. * Fixes memory leak in `reflect_spirv` that would not deallocate the `SpvReflectShaderModule` if an error occurred. * Removes unnecessary allocation when creating `SpvReflectShaderModule` by passing `NO_COPY` flag to `spvReflectCreateShaderModule2` constructor function. * Replaces `VectorView` with `Span` for consistency * Fixes unnecessary allocations in D3D12 shader container in `_convert_spirv_to_nir` and `_convert_spirv_to_dxil` which implicitly converted the old `VectorView` to a `Vector`
1 parent 9283328 commit 65e8b09

12 files changed

+425
-391
lines changed

drivers/d3d12/rendering_shader_container_d3d12.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ uint32_t RenderingShaderContainerD3D12::_to_bytes_footer_extra_data(uint8_t *p_b
268268
}
269269

270270
#if NIR_ENABLED
271-
bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
271+
bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(Span<ReflectedShaderStage> p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
272272
r_stages_processed.clear();
273273

274274
dxil_spirv_runtime_conf dxil_runtime_conf = {};
@@ -287,7 +287,7 @@ bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(const Vector<Rendering
287287
dxil_runtime_conf.inferred_read_only_images_as_srvs = false;
288288

289289
// Translate SPIR-V to NIR.
290-
for (int64_t i = 0; i < p_spirv.size(); i++) {
290+
for (uint64_t i = 0; i < p_spirv.size(); i++) {
291291
RenderingDeviceCommons::ShaderStage stage = p_spirv[i].shader_stage;
292292
RenderingDeviceCommons::ShaderStage stage_flag = (RenderingDeviceCommons::ShaderStage)(1 << stage);
293293
r_stages.push_back(stage);
@@ -302,9 +302,10 @@ bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(const Vector<Rendering
302302
MESA_SHADER_COMPUTE, // SHADER_STAGE_COMPUTE
303303
};
304304

305+
Span<uint32_t> code = p_spirv[i].spirv();
305306
nir_shader *shader = spirv_to_nir(
306-
(const uint32_t *)(p_spirv[i].spirv.ptr()),
307-
p_spirv[i].spirv.size() / sizeof(uint32_t),
307+
code.ptr(),
308+
code.size(),
308309
nullptr,
309310
0,
310311
SPIRV_TO_MESA_STAGES[stage],
@@ -429,7 +430,7 @@ bool RenderingShaderContainerD3D12::_convert_nir_to_dxil(const HashMap<int, nir_
429430
return true;
430431
}
431432

432-
bool RenderingShaderContainerD3D12::_convert_spirv_to_dxil(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
433+
bool RenderingShaderContainerD3D12::_convert_spirv_to_dxil(Span<ReflectedShaderStage> p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
433434
r_dxil_blobs.clear();
434435

435436
HashMap<int, nir_shader *> stages_nir_shaders;
@@ -764,7 +765,7 @@ void RenderingShaderContainerD3D12::_nir_report_bitcode_bit_offset(uint64_t p_bi
764765
}
765766
#endif
766767

767-
void RenderingShaderContainerD3D12::_set_from_shader_reflection_post(const String &p_shader_name, const RenderingDeviceCommons::ShaderReflection &p_reflection) {
768+
void RenderingShaderContainerD3D12::_set_from_shader_reflection_post(const RenderingDeviceCommons::ShaderReflection &p_reflection) {
768769
reflection_binding_set_uniforms_data_d3d12.resize(reflection_binding_set_uniforms_data.size());
769770
reflection_specialization_data_d3d12.resize(reflection_specialization_data.size());
770771

@@ -780,7 +781,7 @@ void RenderingShaderContainerD3D12::_set_from_shader_reflection_post(const Strin
780781
}
781782
}
782783

783-
bool RenderingShaderContainerD3D12::_set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) {
784+
bool RenderingShaderContainerD3D12::_set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) {
784785
#if NIR_ENABLED
785786
reflection_data_d3d12.nir_runtime_data_root_param_idx = UINT32_MAX;
786787

drivers/d3d12/rendering_shader_container_d3d12.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ class RenderingShaderContainerD3D12 : public RenderingShaderContainer {
122122
uint32_t root_signature_crc = 0;
123123

124124
#if NIR_ENABLED
125-
bool _convert_spirv_to_nir(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
125+
bool _convert_spirv_to_nir(Span<ReflectedShaderStage> p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
126126
bool _convert_nir_to_dxil(const HashMap<int, nir_shader *> &p_stages_nir_shaders, BitField<RenderingDeviceCommons::ShaderStage> p_stages_processed, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs);
127-
bool _convert_spirv_to_dxil(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
127+
bool _convert_spirv_to_dxil(Span<ReflectedShaderStage> p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
128128
bool _generate_root_signature(BitField<RenderingDeviceCommons::ShaderStage> p_stages_processed);
129129

130130
// GodotNirCallbacks.
@@ -146,8 +146,8 @@ class RenderingShaderContainerD3D12 : public RenderingShaderContainer {
146146
virtual uint32_t _to_bytes_reflection_binding_uniform_extra_data(uint8_t *p_bytes, uint32_t p_index) const override;
147147
virtual uint32_t _to_bytes_reflection_specialization_extra_data(uint8_t *p_bytes, uint32_t p_index) const override;
148148
virtual uint32_t _to_bytes_footer_extra_data(uint8_t *p_bytes) const override;
149-
virtual void _set_from_shader_reflection_post(const String &p_shader_name, const RenderingDeviceCommons::ShaderReflection &p_reflection) override;
150-
virtual bool _set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) override;
149+
virtual void _set_from_shader_reflection_post(const RenderingDeviceCommons::ShaderReflection &p_reflection) override;
150+
virtual bool _set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) override;
151151

152152
public:
153153
struct ShaderReflectionD3D12 {

drivers/metal/rendering_shader_container_metal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ class RenderingShaderContainerMetal : public RenderingShaderContainer {
292292

293293
virtual uint32_t _format() const override;
294294
virtual uint32_t _format_version() const override;
295-
virtual bool _set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) override;
295+
virtual bool _set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) override;
296296
};
297297

298298
class RenderingShaderContainerFormatMetal : public RenderingShaderContainerFormat {

drivers/metal/rendering_shader_container_metal.mm

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@
253253
#pragma clang diagnostic push
254254
#pragma clang diagnostic ignored "-Wunguarded-availability"
255255

256-
bool RenderingShaderContainerMetal::_set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) {
256+
bool RenderingShaderContainerMetal::_set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) {
257257
using namespace spirv_cross;
258258
using spirv_cross::CompilerMSL;
259259
using spirv_cross::Resource;
@@ -354,12 +354,11 @@
354354

355355
for (uint32_t i = 0; i < p_spirv.size(); i++) {
356356
StageData &stage_data = mtl_shaders.write[i];
357-
RD::ShaderStageSPIRVData const &v = p_spirv[i];
357+
const ReflectedShaderStage &v = p_spirv[i];
358358
RD::ShaderStage stage = v.shader_stage;
359359
char const *stage_name = RD::SHADER_STAGE_NAMES[stage];
360-
uint32_t const *const ir = reinterpret_cast<uint32_t const *const>(v.spirv.ptr());
361-
size_t word_count = v.spirv.size() / sizeof(uint32_t);
362-
Parser parser(ir, word_count);
360+
Span<uint32_t> spirv = v.spirv();
361+
Parser parser(spirv.ptr(), spirv.size());
363362
try {
364363
parser.parse();
365364
} catch (CompilerError &e) {

drivers/vulkan/rendering_shader_container_vulkan.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,21 @@ uint32_t RenderingShaderContainerVulkan::_format_version() const {
4444
return FORMAT_VERSION;
4545
}
4646

47-
bool RenderingShaderContainerVulkan::_set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) {
47+
bool RenderingShaderContainerVulkan::_set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) {
4848
PackedByteArray code_bytes;
4949
shaders.resize(p_spirv.size());
50-
for (int64_t i = 0; i < p_spirv.size(); i++) {
50+
for (uint64_t i = 0; i < p_spirv.size(); i++) {
5151
RenderingShaderContainer::Shader &shader = shaders.ptrw()[i];
52-
5352
if (debug_info_enabled) {
5453
// Store SPIR-V as is when debug info is required.
55-
shader.code_compressed_bytes = p_spirv[i].spirv;
54+
shader.code_compressed_bytes = p_spirv[i].spirv_data();
5655
shader.code_compression_flags = 0;
5756
shader.code_decompressed_size = 0;
5857
} else {
5958
// Encode into smolv.
59+
Span<uint8_t> spirv = p_spirv[i].spirv().reinterpret<uint8_t>();
6060
smolv::ByteArray smolv_bytes;
61-
bool smolv_encoded = smolv::Encode(p_spirv[i].spirv.ptr(), p_spirv[i].spirv.size(), smolv_bytes, smolv::kEncodeFlagStripDebugInfo);
61+
bool smolv_encoded = smolv::Encode(spirv.ptr(), spirv.size(), smolv_bytes, smolv::kEncodeFlagStripDebugInfo);
6262
ERR_FAIL_COND_V_MSG(!smolv_encoded, false, "Failed to compress SPIR-V into smolv.");
6363

6464
code_bytes.resize(smolv_bytes.size());

drivers/vulkan/rendering_shader_container_vulkan.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class RenderingShaderContainerVulkan : public RenderingShaderContainer {
4747
protected:
4848
virtual uint32_t _format() const override;
4949
virtual uint32_t _format_version() const override;
50-
virtual bool _set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) override;
50+
virtual bool _set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) override;
5151

5252
public:
5353
RenderingShaderContainerVulkan(bool p_debug_info_enabled);

editor/export/shader_baker_export_plugin.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -428,15 +428,10 @@ void ShaderBakerExportPlugin::_process_work_item(WorkItem p_work_item) {
428428
Vector<RD::ShaderStageSPIRVData> spirv_data = ShaderRD::compile_stages(p_work_item.stage_sources);
429429
ERR_FAIL_COND_MSG(spirv_data.is_empty(), "Unable to retrieve SPIR-V data for shader");
430430

431-
RD::ShaderReflection shader_refl;
432-
Error err = RenderingDeviceCommons::reflect_spirv(spirv_data, shader_refl);
433-
ERR_FAIL_COND_MSG(err != OK, "Unable to reflect SPIR-V data that was compiled");
434-
435431
Ref<RenderingShaderContainer> shader_container = shader_container_format->create_container();
436-
shader_container->set_from_shader_reflection(p_work_item.shader_name, shader_refl);
437432

438433
// Compile shader binary from SPIR-V.
439-
bool code_compiled = shader_container->set_code_from_spirv(spirv_data);
434+
bool code_compiled = shader_container->set_code_from_spirv(p_work_item.shader_name, spirv_data);
440435
ERR_FAIL_COND_MSG(!code_compiled, vformat("Failed to compile code to native for SPIR-V."));
441436

442437
PackedByteArray shader_bytes = shader_container->to_bytes();

servers/rendering/rendering_device.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3366,19 +3366,12 @@ String RenderingDevice::_shader_uniform_debug(RID p_shader, int p_set) {
33663366
}
33673367

33683368
Vector<uint8_t> RenderingDevice::shader_compile_binary_from_spirv(const Vector<ShaderStageSPIRVData> &p_spirv, const String &p_shader_name) {
3369-
ShaderReflection shader_refl;
3370-
if (reflect_spirv(p_spirv, shader_refl) != OK) {
3371-
return Vector<uint8_t>();
3372-
}
3373-
33743369
const RenderingShaderContainerFormat &container_format = driver->get_shader_container_format();
33753370
Ref<RenderingShaderContainer> shader_container = container_format.create_container();
33763371
ERR_FAIL_COND_V(shader_container.is_null(), Vector<uint8_t>());
33773372

3378-
shader_container->set_from_shader_reflection(p_shader_name, shader_refl);
3379-
33803373
// Compile shader binary from SPIR-V.
3381-
bool code_compiled = shader_container->set_code_from_spirv(p_spirv);
3374+
bool code_compiled = shader_container->set_code_from_spirv(p_shader_name, p_spirv);
33823375
ERR_FAIL_COND_V_MSG(!code_compiled, Vector<uint8_t>(), vformat("Failed to compile code to native for SPIR-V."));
33833376

33843377
return shader_container->to_bytes();

0 commit comments

Comments
 (0)