Skip to content

Commit 4594de6

Browse files
committed
Merge pull request #111013 from stuartcarnie/shader_container_ext
Renderer: Move `reflect_spirv` to `RenderingShaderContainer`
2 parents 25de1a3 + 65e8b09 commit 4594de6

12 files changed

+425
-391
lines changed

drivers/d3d12/rendering_shader_container_d3d12.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ uint32_t RenderingShaderContainerD3D12::_to_bytes_footer_extra_data(uint8_t *p_b
268268
}
269269

270270
#if NIR_ENABLED
271-
bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
271+
bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(Span<ReflectedShaderStage> p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
272272
r_stages_processed.clear();
273273

274274
dxil_spirv_runtime_conf dxil_runtime_conf = {};
@@ -287,7 +287,7 @@ bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(const Vector<Rendering
287287
dxil_runtime_conf.inferred_read_only_images_as_srvs = false;
288288

289289
// Translate SPIR-V to NIR.
290-
for (int64_t i = 0; i < p_spirv.size(); i++) {
290+
for (uint64_t i = 0; i < p_spirv.size(); i++) {
291291
RenderingDeviceCommons::ShaderStage stage = p_spirv[i].shader_stage;
292292
RenderingDeviceCommons::ShaderStage stage_flag = (RenderingDeviceCommons::ShaderStage)(1 << stage);
293293
r_stages.push_back(stage);
@@ -302,9 +302,10 @@ bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(const Vector<Rendering
302302
MESA_SHADER_COMPUTE, // SHADER_STAGE_COMPUTE
303303
};
304304

305+
Span<uint32_t> code = p_spirv[i].spirv();
305306
nir_shader *shader = spirv_to_nir(
306-
(const uint32_t *)(p_spirv[i].spirv.ptr()),
307-
p_spirv[i].spirv.size() / sizeof(uint32_t),
307+
code.ptr(),
308+
code.size(),
308309
nullptr,
309310
0,
310311
SPIRV_TO_MESA_STAGES[stage],
@@ -429,7 +430,7 @@ bool RenderingShaderContainerD3D12::_convert_nir_to_dxil(const HashMap<int, nir_
429430
return true;
430431
}
431432

432-
bool RenderingShaderContainerD3D12::_convert_spirv_to_dxil(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
433+
bool RenderingShaderContainerD3D12::_convert_spirv_to_dxil(Span<ReflectedShaderStage> p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
433434
r_dxil_blobs.clear();
434435

435436
HashMap<int, nir_shader *> stages_nir_shaders;
@@ -764,7 +765,7 @@ void RenderingShaderContainerD3D12::_nir_report_bitcode_bit_offset(uint64_t p_bi
764765
}
765766
#endif
766767

767-
void RenderingShaderContainerD3D12::_set_from_shader_reflection_post(const String &p_shader_name, const RenderingDeviceCommons::ShaderReflection &p_reflection) {
768+
void RenderingShaderContainerD3D12::_set_from_shader_reflection_post(const RenderingDeviceCommons::ShaderReflection &p_reflection) {
768769
reflection_binding_set_uniforms_data_d3d12.resize(reflection_binding_set_uniforms_data.size());
769770
reflection_specialization_data_d3d12.resize(reflection_specialization_data.size());
770771

@@ -780,7 +781,7 @@ void RenderingShaderContainerD3D12::_set_from_shader_reflection_post(const Strin
780781
}
781782
}
782783

783-
bool RenderingShaderContainerD3D12::_set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) {
784+
bool RenderingShaderContainerD3D12::_set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) {
784785
#if NIR_ENABLED
785786
reflection_data_d3d12.nir_runtime_data_root_param_idx = UINT32_MAX;
786787

drivers/d3d12/rendering_shader_container_d3d12.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ class RenderingShaderContainerD3D12 : public RenderingShaderContainer {
122122
uint32_t root_signature_crc = 0;
123123

124124
#if NIR_ENABLED
125-
bool _convert_spirv_to_nir(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
125+
bool _convert_spirv_to_nir(Span<ReflectedShaderStage> p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
126126
bool _convert_nir_to_dxil(const HashMap<int, nir_shader *> &p_stages_nir_shaders, BitField<RenderingDeviceCommons::ShaderStage> p_stages_processed, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs);
127-
bool _convert_spirv_to_dxil(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
127+
bool _convert_spirv_to_dxil(Span<ReflectedShaderStage> p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
128128
bool _generate_root_signature(BitField<RenderingDeviceCommons::ShaderStage> p_stages_processed);
129129

130130
// GodotNirCallbacks.
@@ -146,8 +146,8 @@ class RenderingShaderContainerD3D12 : public RenderingShaderContainer {
146146
virtual uint32_t _to_bytes_reflection_binding_uniform_extra_data(uint8_t *p_bytes, uint32_t p_index) const override;
147147
virtual uint32_t _to_bytes_reflection_specialization_extra_data(uint8_t *p_bytes, uint32_t p_index) const override;
148148
virtual uint32_t _to_bytes_footer_extra_data(uint8_t *p_bytes) const override;
149-
virtual void _set_from_shader_reflection_post(const String &p_shader_name, const RenderingDeviceCommons::ShaderReflection &p_reflection) override;
150-
virtual bool _set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) override;
149+
virtual void _set_from_shader_reflection_post(const RenderingDeviceCommons::ShaderReflection &p_reflection) override;
150+
virtual bool _set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) override;
151151

152152
public:
153153
struct ShaderReflectionD3D12 {

drivers/metal/rendering_shader_container_metal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ class RenderingShaderContainerMetal : public RenderingShaderContainer {
292292

293293
virtual uint32_t _format() const override;
294294
virtual uint32_t _format_version() const override;
295-
virtual bool _set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) override;
295+
virtual bool _set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) override;
296296
};
297297

298298
class RenderingShaderContainerFormatMetal : public RenderingShaderContainerFormat {

drivers/metal/rendering_shader_container_metal.mm

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@
252252
#pragma clang diagnostic push
253253
#pragma clang diagnostic ignored "-Wunguarded-availability"
254254

255-
bool RenderingShaderContainerMetal::_set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) {
255+
bool RenderingShaderContainerMetal::_set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) {
256256
using namespace spirv_cross;
257257
using spirv_cross::CompilerMSL;
258258
using spirv_cross::Resource;
@@ -353,12 +353,11 @@
353353

354354
for (uint32_t i = 0; i < p_spirv.size(); i++) {
355355
StageData &stage_data = mtl_shaders.write[i];
356-
RD::ShaderStageSPIRVData const &v = p_spirv[i];
356+
const ReflectedShaderStage &v = p_spirv[i];
357357
RD::ShaderStage stage = v.shader_stage;
358358
char const *stage_name = RD::SHADER_STAGE_NAMES[stage];
359-
uint32_t const *const ir = reinterpret_cast<uint32_t const *const>(v.spirv.ptr());
360-
size_t word_count = v.spirv.size() / sizeof(uint32_t);
361-
Parser parser(ir, word_count);
359+
Span<uint32_t> spirv = v.spirv();
360+
Parser parser(spirv.ptr(), spirv.size());
362361
try {
363362
parser.parse();
364363
} catch (CompilerError &e) {

drivers/vulkan/rendering_shader_container_vulkan.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,21 @@ uint32_t RenderingShaderContainerVulkan::_format_version() const {
4444
return FORMAT_VERSION;
4545
}
4646

47-
bool RenderingShaderContainerVulkan::_set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) {
47+
bool RenderingShaderContainerVulkan::_set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) {
4848
PackedByteArray code_bytes;
4949
shaders.resize(p_spirv.size());
50-
for (int64_t i = 0; i < p_spirv.size(); i++) {
50+
for (uint64_t i = 0; i < p_spirv.size(); i++) {
5151
RenderingShaderContainer::Shader &shader = shaders.ptrw()[i];
52-
5352
if (debug_info_enabled) {
5453
// Store SPIR-V as is when debug info is required.
55-
shader.code_compressed_bytes = p_spirv[i].spirv;
54+
shader.code_compressed_bytes = p_spirv[i].spirv_data();
5655
shader.code_compression_flags = 0;
5756
shader.code_decompressed_size = 0;
5857
} else {
5958
// Encode into smolv.
59+
Span<uint8_t> spirv = p_spirv[i].spirv().reinterpret<uint8_t>();
6060
smolv::ByteArray smolv_bytes;
61-
bool smolv_encoded = smolv::Encode(p_spirv[i].spirv.ptr(), p_spirv[i].spirv.size(), smolv_bytes, smolv::kEncodeFlagStripDebugInfo);
61+
bool smolv_encoded = smolv::Encode(spirv.ptr(), spirv.size(), smolv_bytes, smolv::kEncodeFlagStripDebugInfo);
6262
ERR_FAIL_COND_V_MSG(!smolv_encoded, false, "Failed to compress SPIR-V into smolv.");
6363

6464
code_bytes.resize(smolv_bytes.size());

drivers/vulkan/rendering_shader_container_vulkan.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class RenderingShaderContainerVulkan : public RenderingShaderContainer {
4747
protected:
4848
virtual uint32_t _format() const override;
4949
virtual uint32_t _format_version() const override;
50-
virtual bool _set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) override;
50+
virtual bool _set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) override;
5151

5252
public:
5353
RenderingShaderContainerVulkan(bool p_debug_info_enabled);

editor/export/shader_baker_export_plugin.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -428,15 +428,10 @@ void ShaderBakerExportPlugin::_process_work_item(WorkItem p_work_item) {
428428
Vector<RD::ShaderStageSPIRVData> spirv_data = ShaderRD::compile_stages(p_work_item.stage_sources);
429429
ERR_FAIL_COND_MSG(spirv_data.is_empty(), "Unable to retrieve SPIR-V data for shader");
430430

431-
RD::ShaderReflection shader_refl;
432-
Error err = RenderingDeviceCommons::reflect_spirv(spirv_data, shader_refl);
433-
ERR_FAIL_COND_MSG(err != OK, "Unable to reflect SPIR-V data that was compiled");
434-
435431
Ref<RenderingShaderContainer> shader_container = shader_container_format->create_container();
436-
shader_container->set_from_shader_reflection(p_work_item.shader_name, shader_refl);
437432

438433
// Compile shader binary from SPIR-V.
439-
bool code_compiled = shader_container->set_code_from_spirv(spirv_data);
434+
bool code_compiled = shader_container->set_code_from_spirv(p_work_item.shader_name, spirv_data);
440435
ERR_FAIL_COND_MSG(!code_compiled, vformat("Failed to compile code to native for SPIR-V."));
441436

442437
PackedByteArray shader_bytes = shader_container->to_bytes();

servers/rendering/rendering_device.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3366,19 +3366,12 @@ String RenderingDevice::_shader_uniform_debug(RID p_shader, int p_set) {
33663366
}
33673367

33683368
Vector<uint8_t> RenderingDevice::shader_compile_binary_from_spirv(const Vector<ShaderStageSPIRVData> &p_spirv, const String &p_shader_name) {
3369-
ShaderReflection shader_refl;
3370-
if (reflect_spirv(p_spirv, shader_refl) != OK) {
3371-
return Vector<uint8_t>();
3372-
}
3373-
33743369
const RenderingShaderContainerFormat &container_format = driver->get_shader_container_format();
33753370
Ref<RenderingShaderContainer> shader_container = container_format.create_container();
33763371
ERR_FAIL_COND_V(shader_container.is_null(), Vector<uint8_t>());
33773372

3378-
shader_container->set_from_shader_reflection(p_shader_name, shader_refl);
3379-
33803373
// Compile shader binary from SPIR-V.
3381-
bool code_compiled = shader_container->set_code_from_spirv(p_spirv);
3374+
bool code_compiled = shader_container->set_code_from_spirv(p_shader_name, p_spirv);
33823375
ERR_FAIL_COND_V_MSG(!code_compiled, Vector<uint8_t>(), vformat("Failed to compile code to native for SPIR-V."));
33833376

33843377
return shader_container->to_bytes();

0 commit comments

Comments
 (0)