diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e9bc2d0f63..4d9a84a7341 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,27 @@ We have merged the acceleration structure feature into the `RayQuery` feature. T By @Vecvec in [#7913](https://github.com/gfx-rs/wgpu/pull/7913). +#### New `EXPERIMENTAL_PRECOMPILED_SHADERS` API +We have added `Features::EXPERIMENTAL_PRECOMPILED_SHADERS`, replacing existing passthrough types with a unified `CreateShaderModuleDescriptorPassthrough` which allows passing multiple shader codes for different backends. By @SupaMaggie70Incorporated in [#7834](https://github.com/gfx-rs/wgpu/pull/7834) + +Difference for SPIR-V passthrough: +```diff +- device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV( +- wgpu::ShaderModuleDescriptorSpirV { +- label: None, +- source: spirv_code, +- }, +- )) ++ device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough { ++ entry_point: "main".into(), ++ label: None, ++ spirv: Some(spirv_code), ++ ..Default::default() +}) +``` +This allows using precompiled shaders without manually checking which backend's code to pass, for example if you have shaders precompiled for both DXIL and SPIR-V. + + ### New Features #### General diff --git a/Cargo.lock b/Cargo.lock index 56128ca6616..2d3f7dcd13d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3191,6 +3191,7 @@ checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" name = "player" version = "26.0.0" dependencies = [ + "bytemuck", "env_logger", "log", "raw-window-handle 0.6.2", diff --git a/deno_webgpu/webidl.rs b/deno_webgpu/webidl.rs index 533da00c830..75a19d372c3 100644 --- a/deno_webgpu/webidl.rs +++ b/deno_webgpu/webidl.rs @@ -419,10 +419,6 @@ pub enum GPUFeatureName { VertexWritableStorage, #[webidl(rename = "clear-texture")] ClearTexture, - #[webidl(rename = "msl-shader-passthrough")] - MslShaderPassthrough, - #[webidl(rename = "spirv-shader-passthrough")] - SpirvShaderPassthrough, #[webidl(rename = "multiview")] Multiview, #[webidl(rename = "vertex-attribute-64-bit")] @@ -435,6 +431,8 @@ pub enum GPUFeatureName { ShaderPrimitiveIndex, #[webidl(rename = "shader-early-depth-test")] ShaderEarlyDepthTest, + #[webidl(rename = "passthrough-shaders")] + PassthroughShaders, } pub fn feature_names_to_features(names: Vec) -> wgpu_types::Features { @@ -482,14 +480,13 @@ pub fn feature_names_to_features(names: Vec) -> wgpu_types::Feat GPUFeatureName::ConservativeRasterization => Features::CONSERVATIVE_RASTERIZATION, GPUFeatureName::VertexWritableStorage => Features::VERTEX_WRITABLE_STORAGE, GPUFeatureName::ClearTexture => Features::CLEAR_TEXTURE, - GPUFeatureName::MslShaderPassthrough => Features::MSL_SHADER_PASSTHROUGH, - GPUFeatureName::SpirvShaderPassthrough => Features::SPIRV_SHADER_PASSTHROUGH, GPUFeatureName::Multiview => Features::MULTIVIEW, GPUFeatureName::VertexAttribute64Bit => Features::VERTEX_ATTRIBUTE_64BIT, GPUFeatureName::ShaderF64 => Features::SHADER_F64, GPUFeatureName::ShaderI16 => Features::SHADER_I16, GPUFeatureName::ShaderPrimitiveIndex => Features::SHADER_PRIMITIVE_INDEX, GPUFeatureName::ShaderEarlyDepthTest => Features::SHADER_EARLY_DEPTH_TEST, + GPUFeatureName::PassthroughShaders => Features::EXPERIMENTAL_PASSTHROUGH_SHADERS, }; features.set(feature, true); } @@ -626,9 +623,6 @@ pub fn features_to_feature_names(features: wgpu_types::Features) -> HashSet HashSet wgpu::Features { - wgpu::Features::EXPERIMENTAL_MESH_SHADER | wgpu::Features::SPIRV_SHADER_PASSTHROUGH + wgpu::Features::EXPERIMENTAL_MESH_SHADER | wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS } fn required_limits() -> wgpu::Limits { wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values() diff --git a/player/Cargo.toml b/player/Cargo.toml index aadf4e5e8ff..f29aacf0e94 100644 --- a/player/Cargo.toml +++ b/player/Cargo.toml @@ -26,6 +26,7 @@ log.workspace = true raw-window-handle.workspace = true ron.workspace = true winit = { workspace = true, optional = true } +bytemuck.workspace = true # Non-Webassembly # diff --git a/player/src/lib.rs b/player/src/lib.rs index 5076db0267d..8ba7e13ce1b 100644 --- a/player/src/lib.rs +++ b/player/src/lib.rs @@ -315,6 +315,84 @@ impl GlobalPlay for wgc::global::Global { println!("shader compilation error:\n---{code}\n---\n{e}"); } } + Action::CreateShaderModulePassthrough { + id, + data, + entry_point, + label, + num_workgroups, + runtime_checks, + } => { + let spirv = data.iter().find_map(|a| { + if a.ends_with(".spv") { + let data = fs::read(dir.join(a)).unwrap(); + assert!(data.len() % 4 == 0); + + Some(Cow::Owned(bytemuck::pod_collect_to_vec(&data))) + } else { + None + } + }); + let dxil = data.iter().find_map(|a| { + if a.ends_with(".dxil") { + let vec = std::fs::read(dir.join(a)).unwrap(); + Some(Cow::Owned(vec)) + } else { + None + } + }); + let hlsl = data.iter().find_map(|a| { + if a.ends_with(".hlsl") { + let code = fs::read_to_string(dir.join(a)).unwrap(); + Some(Cow::Owned(code)) + } else { + None + } + }); + let msl = data.iter().find_map(|a| { + if a.ends_with(".msl") { + let code = fs::read_to_string(dir.join(a)).unwrap(); + Some(Cow::Owned(code)) + } else { + None + } + }); + let glsl = data.iter().find_map(|a| { + if a.ends_with(".glsl") { + let code = fs::read_to_string(dir.join(a)).unwrap(); + Some(Cow::Owned(code)) + } else { + None + } + }); + let wgsl = data.iter().find_map(|a| { + if a.ends_with(".wgsl") { + let code = fs::read_to_string(dir.join(a)).unwrap(); + Some(Cow::Owned(code)) + } else { + None + } + }); + let desc = wgt::CreateShaderModuleDescriptorPassthrough { + entry_point, + label, + num_workgroups, + runtime_checks, + + spirv, + dxil, + hlsl, + msl, + glsl, + wgsl, + }; + let (_, error) = unsafe { + self.device_create_shader_module_passthrough(device, &desc, Some(id)) + }; + if let Some(e) = error { + println!("shader compilation error: {e}"); + } + } Action::DestroyShaderModule(id) => { self.shader_module_drop(id); } diff --git a/tests/tests/wgpu-gpu/mesh_shader/mod.rs b/tests/tests/wgpu-gpu/mesh_shader/mod.rs index 8a0bdf4b80f..4dd897129f6 100644 --- a/tests/tests/wgpu-gpu/mesh_shader/mod.rs +++ b/tests/tests/wgpu-gpu/mesh_shader/mod.rs @@ -41,12 +41,12 @@ fn compile_glsl( let output = cmd.wait_with_output().expect("Error waiting for glslc"); assert!(output.status.success()); unsafe { - device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV( - wgpu::ShaderModuleDescriptorSpirV { - label: None, - source: wgpu::util::make_spirv_raw(&output.stdout), - }, - )) + device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough { + entry_point: "main".into(), + label: None, + spirv: Some(wgpu::util::make_spirv_raw(&output.stdout)), + ..Default::default() + }) } } @@ -267,7 +267,7 @@ fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration { .test_features_limits() .features( wgpu::Features::EXPERIMENTAL_MESH_SHADER - | wgpu::Features::SPIRV_SHADER_PASSTHROUGH + | wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS | match draw_type { DrawType::Standard | DrawType::Indirect => wgpu::Features::empty(), DrawType::MultiIndirect => wgpu::Features::MULTI_DRAW_INDIRECT, diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 92f10fc07fc..4b35638c48f 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -1094,36 +1094,27 @@ impl Global { #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { - let data = trace.make_binary(desc.trace_binary_ext(), desc.trace_data()); - trace.add(trace::Action::CreateShaderModule { + let mut file_names = Vec::new(); + for (data, ext) in [ + (desc.spirv.as_ref().map(|a| bytemuck::cast_slice(a)), "spv"), + (desc.dxil.as_deref(), "dxil"), + (desc.hlsl.as_ref().map(|a| a.as_bytes()), "hlsl"), + (desc.msl.as_ref().map(|a| a.as_bytes()), "msl"), + (desc.glsl.as_ref().map(|a| a.as_bytes()), "glsl"), + (desc.wgsl.as_ref().map(|a| a.as_bytes()), "wgsl"), + ] { + if let Some(data) = data { + file_names.push(trace.make_binary(ext, data)); + } + } + trace.add(trace::Action::CreateShaderModulePassthrough { id: fid.id(), - desc: match desc { - pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => { - pipeline::ShaderModuleDescriptor { - label: inner.label.clone(), - runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), - } - } - pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => { - pipeline::ShaderModuleDescriptor { - label: inner.label.clone(), - runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), - } - } - pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => { - pipeline::ShaderModuleDescriptor { - label: inner.label.clone(), - runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), - } - } - pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => { - pipeline::ShaderModuleDescriptor { - label: inner.label.clone(), - runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), - } - } - }, - data, + data: file_names, + + entry_point: desc.entry_point.clone(), + label: desc.label.clone(), + num_workgroups: desc.num_workgroups, + runtime_checks: desc.runtime_checks, }); }; @@ -1138,7 +1129,7 @@ impl Global { return (id, None); }; - let id = fid.assign(Fallible::Invalid(Arc::new(desc.label().to_string()))); + let id = fid.assign(Fallible::Invalid(Arc::new(desc.label.to_string()))); (id, Some(error)) } diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index a7cacea5945..ba894a977aa 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -2125,39 +2125,59 @@ impl Device { descriptor: &pipeline::ShaderModuleDescriptorPassthrough<'a>, ) -> Result, pipeline::CreateShaderModuleError> { self.check_is_valid()?; - let hal_shader = match descriptor { - pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => { - self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?; - hal::ShaderInput::SpirV(&inner.source) - } - pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => { - self.require_features(wgt::Features::MSL_SHADER_PASSTHROUGH)?; - hal::ShaderInput::Msl { - shader: inner.source.to_string(), - entry_point: inner.entry_point.to_string(), - num_workgroups: inner.num_workgroups, - } - } - pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => { - self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?; - hal::ShaderInput::Dxil { - shader: inner.source, - entry_point: inner.entry_point.clone(), - num_workgroups: inner.num_workgroups, + self.require_features(wgt::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS)?; + + // TODO: when we get to use if-let chains, this will be a little nicer! + + log::info!("Backend: {}", self.backend()); + let hal_shader = match self.backend() { + wgt::Backend::Vulkan => hal::ShaderInput::SpirV( + descriptor + .spirv + .as_ref() + .ok_or(pipeline::CreateShaderModuleError::NotCompiledForBackend)?, + ), + wgt::Backend::Dx12 => { + if let Some(dxil) = &descriptor.dxil { + hal::ShaderInput::Dxil { + shader: dxil, + entry_point: descriptor.entry_point.clone(), + num_workgroups: descriptor.num_workgroups, + } + } else if let Some(hlsl) = &descriptor.hlsl { + hal::ShaderInput::Hlsl { + shader: hlsl, + entry_point: descriptor.entry_point.clone(), + num_workgroups: descriptor.num_workgroups, + } + } else { + return Err(pipeline::CreateShaderModuleError::NotCompiledForBackend); } } - pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => { - self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?; - hal::ShaderInput::Hlsl { - shader: inner.source, - entry_point: inner.entry_point.clone(), - num_workgroups: inner.num_workgroups, - } + wgt::Backend::Metal => hal::ShaderInput::Msl { + shader: descriptor + .msl + .as_ref() + .ok_or(pipeline::CreateShaderModuleError::NotCompiledForBackend)?, + entry_point: descriptor.entry_point.clone(), + num_workgroups: descriptor.num_workgroups, + }, + wgt::Backend::Gl => hal::ShaderInput::Glsl { + shader: descriptor + .glsl + .as_ref() + .ok_or(pipeline::CreateShaderModuleError::NotCompiledForBackend)?, + entry_point: descriptor.entry_point.clone(), + num_workgroups: descriptor.num_workgroups, + }, + wgt::Backend::Noop => { + return Err(pipeline::CreateShaderModuleError::NotCompiledForBackend) } + wgt::Backend::BrowserWebGpu => unreachable!(), }; let hal_desc = hal::ShaderModuleDescriptor { - label: descriptor.label().to_hal(self.instance_flags), + label: descriptor.label.to_hal(self.instance_flags), runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), }; @@ -2180,7 +2200,7 @@ impl Device { raw: ManuallyDrop::new(raw), device: self.clone(), interface: None, - label: descriptor.label().to_string(), + label: descriptor.label.to_string(), }; Ok(Arc::new(module)) diff --git a/wgpu-core/src/device/trace.rs b/wgpu-core/src/device/trace.rs index 58d26e4b079..80432d5e938 100644 --- a/wgpu-core/src/device/trace.rs +++ b/wgpu-core/src/device/trace.rs @@ -93,6 +93,15 @@ pub enum Action<'a> { desc: crate::pipeline::ShaderModuleDescriptor<'a>, data: FileName, }, + CreateShaderModulePassthrough { + id: id::ShaderModuleId, + data: Vec, + + entry_point: String, + label: crate::Label<'a>, + num_workgroups: (u32, u32, u32), + runtime_checks: wgt::ShaderRuntimeChecks, + }, DestroyShaderModule(id::ShaderModuleId), CreateComputePipeline { id: id::ComputePipelineId, diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index 580fdc4d94e..753518f67ad 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -130,6 +130,8 @@ pub enum CreateShaderModuleError { group: u32, limit: u32, }, + #[error("Generic shader passthrough does not contain any code compatible with this backend.")] + NotCompiledForBackend, } impl WebGpuError for CreateShaderModuleError { @@ -147,6 +149,7 @@ impl WebGpuError for CreateShaderModuleError { Self::ParsingGlsl(..) => return ErrorType::Validation, #[cfg(feature = "spirv")] Self::ParsingSpirV(..) => return ErrorType::Validation, + Self::NotCompiledForBackend => return ErrorType::Validation, }; e.webgpu_error_type() } diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index 1f57fa69010..88d01103629 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -363,8 +363,8 @@ impl super::Adapter { | wgt::Features::TEXTURE_FORMAT_NV12 | wgt::Features::FLOAT32_FILTERABLE | wgt::Features::TEXTURE_ATOMIC - | wgt::Features::EXTERNAL_TEXTURE - | wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH; + | wgt::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS + | wgt::Features::EXTERNAL_TEXTURE; //TODO: in order to expose this, we need to run a compute shader // that extract the necessary statistics out of the D3D12 result. diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index d62fc5c6751..34454c9f963 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -1777,12 +1777,6 @@ impl crate::Device for super::Device { raw_name, runtime_checks: desc.runtime_checks, }), - crate::ShaderInput::SpirV(_) => { - panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend") - } - crate::ShaderInput::Msl { .. } => { - panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend") - } crate::ShaderInput::Dxil { shader, entry_point, @@ -1809,6 +1803,11 @@ impl crate::Device for super::Device { raw_name, runtime_checks: desc.runtime_checks, }), + crate::ShaderInput::SpirV(_) + | crate::ShaderInput::Msl { .. } + | crate::ShaderInput::Glsl { .. } => { + unreachable!() + } } } unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) { diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index 0b1b77e11bc..dda5525c61c 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -222,8 +222,8 @@ impl super::Device { }; let (module, info) = naga::back::pipeline_constants::process_overrides( - &stage.module.naga.module, - &stage.module.naga.info, + &stage.module.source.module, + &stage.module.source.info, Some((naga_stage, stage.entry_point)), stage.constants, ) @@ -463,7 +463,7 @@ impl super::Device { for (stage_idx, stage_items) in push_constant_items.into_iter().enumerate() { for item in stage_items { - let naga_module = &shaders[stage_idx].1.module.naga.module; + let naga_module = &shaders[stage_idx].1.module.source.module; let type_inner = &naga_module.types[item.ty].inner; let location = unsafe { gl.get_uniform_location(program, &item.access_path) }; @@ -1334,16 +1334,15 @@ impl crate::Device for super::Device { self.counters.shader_modules.add(1); Ok(super::ShaderModule { - naga: match shader { - crate::ShaderInput::SpirV(_) => { - panic!("`Features::SPIRV_SHADER_PASSTHROUGH` is not enabled") - } - crate::ShaderInput::Msl { .. } => { - panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled") - } + source: match shader { crate::ShaderInput::Naga(naga) => naga, - crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => { - panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled") + // The backend doesn't yet expose this feature so it should be fine + crate::ShaderInput::Glsl { .. } => unimplemented!(), + crate::ShaderInput::SpirV(_) + | crate::ShaderInput::Msl { .. } + | crate::ShaderInput::Dxil { .. } + | crate::ShaderInput::Hlsl { .. } => { + unreachable!() } }, label: desc.label.map(|str| str.to_string()), diff --git a/wgpu-hal/src/gles/mod.rs b/wgpu-hal/src/gles/mod.rs index 94416086d2e..b56a851e395 100644 --- a/wgpu-hal/src/gles/mod.rs +++ b/wgpu-hal/src/gles/mod.rs @@ -605,7 +605,7 @@ type ShaderId = u32; #[derive(Debug)] pub struct ShaderModule { - naga: crate::NagaShader, + source: crate::NagaShader, label: Option, id: ShaderId, } diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index e60b07e9a6a..b4255a6c811 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -2219,7 +2219,7 @@ impl fmt::Debug for NagaShader { pub enum ShaderInput<'a> { Naga(NagaShader), Msl { - shader: String, + shader: &'a str, entry_point: String, num_workgroups: (u32, u32, u32), }, @@ -2234,6 +2234,11 @@ pub enum ShaderInput<'a> { entry_point: String, num_workgroups: (u32, u32, u32), }, + Glsl { + shader: &'a str, + entry_point: String, + num_workgroups: (u32, u32, u32), + }, } pub struct ShaderModuleDescriptor<'a> { diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index d02e38980ac..02dfc0fe601 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -917,7 +917,6 @@ impl super::PrivateCapabilities { use wgt::Features as F; let mut features = F::empty() - | F::MSL_SHADER_PASSTHROUGH | F::MAPPABLE_PRIMARY_BUFFERS | F::VERTEX_WRITABLE_STORAGE | F::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES @@ -927,7 +926,8 @@ impl super::PrivateCapabilities { | F::TEXTURE_FORMAT_16BIT_NORM | F::SHADER_F16 | F::DEPTH32FLOAT_STENCIL8 - | F::BGRA8UNORM_STORAGE; + | F::BGRA8UNORM_STORAGE + | F::EXPERIMENTAL_PASSTHROUGH_SHADERS; features.set(F::FLOAT32_FILTERABLE, self.supports_float_filtering); features.set( diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index fbb166c2723..6af8ad3062d 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1017,7 +1017,7 @@ impl crate::Device for super::Device { // Obtain the locked device from shared let device = self.shared.device.lock(); let library = device - .new_library_with_source(&source, &options) + .new_library_with_source(source, &options) .map_err(|e| crate::ShaderError::Compilation(format!("MSL: {e:?}")))?; let function = library.get_function(&entry_point, None).map_err(|_| { crate::ShaderError::Compilation(format!( @@ -1035,12 +1035,10 @@ impl crate::Device for super::Device { bounds_checks: desc.runtime_checks, }) } - crate::ShaderInput::SpirV(_) => { - panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend") - } - crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => { - panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled for this backend") - } + crate::ShaderInput::SpirV(_) + | crate::ShaderInput::Dxil { .. } + | crate::ShaderInput::Hlsl { .. } + | crate::ShaderInput::Glsl { .. } => unreachable!(), } } diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index bb4e2a9d4ae..f94e1ac3272 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -543,7 +543,6 @@ impl PhysicalDeviceFeatures { ) -> (wgt::Features, wgt::DownlevelFlags) { use wgt::{DownlevelFlags as Df, Features as F}; let mut features = F::empty() - | F::SPIRV_SHADER_PASSTHROUGH | F::MAPPABLE_PRIMARY_BUFFERS | F::PUSH_CONSTANTS | F::ADDRESS_MODE_CLAMP_TO_BORDER @@ -555,7 +554,8 @@ impl PhysicalDeviceFeatures { | F::CLEAR_TEXTURE | F::PIPELINE_CACHE | F::SHADER_EARLY_DEPTH_TEST - | F::TEXTURE_ATOMIC; + | F::TEXTURE_ATOMIC + | F::EXPERIMENTAL_PASSTHROUGH_SHADERS; let mut dl_flags = Df::COMPUTE_SHADERS | Df::BASE_VERTEX diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 48be11a124d..42771f5c597 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -1940,13 +1940,11 @@ impl crate::Device for super::Device { .map_err(|e| crate::ShaderError::Compilation(format!("{e}")))?, ) } - crate::ShaderInput::Msl { .. } => { - panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend") - } - crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => { - panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled") - } - crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv), + crate::ShaderInput::SpirV(data) => Cow::Borrowed(data), + crate::ShaderInput::Msl { .. } + | crate::ShaderInput::Dxil { .. } + | crate::ShaderInput::Hlsl { .. } + | crate::ShaderInput::Glsl { .. } => unreachable!(), }; let raw = self.create_shader_module_impl(&spv)?; diff --git a/wgpu-types/src/features.rs b/wgpu-types/src/features.rs index af3a920cd61..8f7418b018d 100644 --- a/wgpu-types/src/features.rs +++ b/wgpu-types/src/features.rs @@ -926,29 +926,6 @@ bitflags_array! { /// /// This is a native only feature. const CLEAR_TEXTURE = 1 << 23; - /// Enables creating shader modules from Metal MSL computer shaders (unsafe). - /// - /// Metal data is not parsed or interpreted in any way - /// - /// Supported platforms: - /// - Metal - /// - /// This is a native only feature. - const MSL_SHADER_PASSTHROUGH = 1 << 24; - /// Enables creating shader modules from SPIR-V binary data (unsafe). - /// - /// SPIR-V data is not parsed or interpreted in any way; you can use - /// [`wgpu::make_spirv_raw!`] to check for alignment and magic number when converting from - /// raw bytes. - /// - /// Supported platforms: - /// - Vulkan, in case shader's requested capabilities and extensions agree with - /// Vulkan implementation. - /// - /// This is a native only feature. - /// - /// [`wgpu::make_spirv_raw!`]: https://docs.rs/wgpu/latest/wgpu/macro.include_spirv_raw.html - const SPIRV_SHADER_PASSTHROUGH = 1 << 25; /// Enables multiview render passes and `builtin(view_index)` in vertex shaders. /// /// Supported platforms: @@ -1243,15 +1220,23 @@ bitflags_array! { /// [BlasTriangleGeometrySizeDescriptor::vertex_format]: super::BlasTriangleGeometrySizeDescriptor const EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS = 1 << 51; - /// Enables creating shader modules from DirectX HLSL or DXIL shaders (unsafe) + /// Enables creating shaders from passthrough with reflection info (unsafe) /// - /// HLSL/DXIL data is not parsed or interpreted in any way + /// Allows using [`Device::create_shader_module_passthrough`]. + /// Shader code isn't parsed or interpreted in any way. It is the user's + /// responsibility to ensure the code and reflection (if passed) are correct. /// - /// Supported platforms: + /// Supported platforms + /// - Vulkan /// - DX12 + /// - Metal + /// - WebGPU /// - /// This is a native only feature. - const HLSL_DXIL_SHADER_PASSTHROUGH = 1 << 52; + /// Ideally, in the future, all platforms will be supported. For more info, see + /// [this comment](https://github.com/gfx-rs/wgpu/issues/3103#issuecomment-2833058367). + /// + /// [`Device::create_shader_module_passthrough`]: https://docs.rs/wgpu/latest/wgpu/struct.Device.html#method.create_shader_module_passthrough + const EXPERIMENTAL_PASSTHROUGH_SHADERS = 1 << 52; } /// Features that are not guaranteed to be supported. diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index fb71e420262..ea2a09eb62a 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -8032,20 +8032,52 @@ pub enum DeviceLostReason { Destroyed = 1, } -/// Descriptor for creating a shader module. -/// -/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, -/// only WGSL source code strings are accepted. +/// Descriptor for a shader module given by any of several sources. +/// These shaders are passed through directly to the underlying api. +/// At least one shader type that may be used by the backend must be `Some` or a panic is raised. #[derive(Debug, Clone)] -pub enum CreateShaderModuleDescriptorPassthrough<'a, L> { - /// Passthrough for SPIR-V binaries. - SpirV(ShaderModuleDescriptorSpirV<'a, L>), - /// Passthrough for MSL source code. - Msl(ShaderModuleDescriptorMsl<'a, L>), - /// Passthrough for DXIL compiled with DXC - Dxil(ShaderModuleDescriptorDxil<'a, L>), - /// Passthrough for HLSL - Hlsl(ShaderModuleDescriptorHlsl<'a, L>), +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct CreateShaderModuleDescriptorPassthrough<'a, L> { + /// Entrypoint. Unused for Spir-V. + pub entry_point: String, + /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. + pub label: L, + /// Number of workgroups in each dimension x, y and z. Unused for Spir-V. + pub num_workgroups: (u32, u32, u32), + /// Runtime checks that should be enabled. + pub runtime_checks: ShaderRuntimeChecks, + + /// Binary SPIR-V data, in 4-byte words. + pub spirv: Option>, + /// Shader DXIL source. + pub dxil: Option>, + /// Shader MSL source. + pub msl: Option>, + /// Shader HLSL source. + pub hlsl: Option>, + /// Shader GLSL source (currently unused). + pub glsl: Option>, + /// Shader WGSL source. + pub wgsl: Option>, +} + +// This is so people don't have to fill in fields they don't use, like num_workgroups, +// entry_point, or other shader languages they didn't compile for +impl<'a, L: Default> Default for CreateShaderModuleDescriptorPassthrough<'a, L> { + fn default() -> Self { + Self { + entry_point: "".into(), + label: Default::default(), + num_workgroups: (0, 0, 0), + runtime_checks: ShaderRuntimeChecks::unchecked(), + spirv: None, + dxil: None, + msl: None, + hlsl: None, + glsl: None, + wgsl: None, + } + } } impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { @@ -8053,134 +8085,46 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { pub fn map_label( &self, fun: impl FnOnce(&L) -> K, - ) -> CreateShaderModuleDescriptorPassthrough<'_, K> { - match self { - CreateShaderModuleDescriptorPassthrough::SpirV(inner) => { - CreateShaderModuleDescriptorPassthrough::<'_, K>::SpirV( - ShaderModuleDescriptorSpirV { - label: fun(&inner.label), - source: inner.source.clone(), - }, - ) - } - CreateShaderModuleDescriptorPassthrough::Msl(inner) => { - CreateShaderModuleDescriptorPassthrough::<'_, K>::Msl(ShaderModuleDescriptorMsl { - entry_point: inner.entry_point.clone(), - label: fun(&inner.label), - num_workgroups: inner.num_workgroups, - source: inner.source.clone(), - }) - } - CreateShaderModuleDescriptorPassthrough::Dxil(inner) => { - CreateShaderModuleDescriptorPassthrough::<'_, K>::Dxil(ShaderModuleDescriptorDxil { - entry_point: inner.entry_point.clone(), - label: fun(&inner.label), - num_workgroups: inner.num_workgroups, - source: inner.source, - }) - } - CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => { - CreateShaderModuleDescriptorPassthrough::<'_, K>::Hlsl(ShaderModuleDescriptorHlsl { - entry_point: inner.entry_point.clone(), - label: fun(&inner.label), - num_workgroups: inner.num_workgroups, - source: inner.source, - }) - } - } - } - - /// Returns the label of shader module passthrough descriptor. - pub fn label(&'a self) -> &'a L { - match self { - CreateShaderModuleDescriptorPassthrough::SpirV(inner) => &inner.label, - CreateShaderModuleDescriptorPassthrough::Msl(inner) => &inner.label, - CreateShaderModuleDescriptorPassthrough::Dxil(inner) => &inner.label, - CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => &inner.label, + ) -> CreateShaderModuleDescriptorPassthrough<'a, K> { + CreateShaderModuleDescriptorPassthrough { + entry_point: self.entry_point.clone(), + label: fun(&self.label), + num_workgroups: self.num_workgroups, + runtime_checks: self.runtime_checks, + spirv: self.spirv.clone(), + dxil: self.dxil.clone(), + msl: self.msl.clone(), + hlsl: self.hlsl.clone(), + glsl: self.glsl.clone(), + wgsl: self.wgsl.clone(), } } #[cfg(feature = "trace")] /// Returns the source data for tracing purpose. pub fn trace_data(&self) -> &[u8] { - match self { - CreateShaderModuleDescriptorPassthrough::SpirV(inner) => { - bytemuck::cast_slice(&inner.source) - } - CreateShaderModuleDescriptorPassthrough::Msl(inner) => inner.source.as_bytes(), - CreateShaderModuleDescriptorPassthrough::Dxil(inner) => inner.source, - CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => inner.source.as_bytes(), + if let Some(spirv) = &self.spirv { + bytemuck::cast_slice(spirv) + } else if let Some(msl) = &self.msl { + msl.as_bytes() + } else if let Some(dxil) = &self.dxil { + dxil + } else { + panic!("No binary data provided to `ShaderModuleDescriptorGeneric`") } } #[cfg(feature = "trace")] /// Returns the binary file extension for tracing purpose. pub fn trace_binary_ext(&self) -> &'static str { - match self { - CreateShaderModuleDescriptorPassthrough::SpirV(..) => "spv", - CreateShaderModuleDescriptorPassthrough::Msl(..) => "msl", - CreateShaderModuleDescriptorPassthrough::Dxil(..) => "dxil", - CreateShaderModuleDescriptorPassthrough::Hlsl(..) => "hlsl", + if self.spirv.is_some() { + "spv" + } else if self.msl.is_some() { + "msl" + } else if self.dxil.is_some() { + "dxil" + } else { + panic!("No binary data provided to `ShaderModuleDescriptorGeneric`") } } } - -/// Descriptor for a shader module given by Metal MSL source. -/// -/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, -/// only WGSL source code strings are accepted. -#[derive(Debug, Clone)] -pub struct ShaderModuleDescriptorMsl<'a, L> { - /// Entrypoint. - pub entry_point: String, - /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. - pub label: L, - /// Number of workgroups in each dimension x, y and z. - pub num_workgroups: (u32, u32, u32), - /// Shader MSL source. - pub source: Cow<'a, str>, -} - -/// Descriptor for a shader module given by DirectX DXIL source. -/// -/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, -/// only WGSL source code strings are accepted. -#[derive(Debug, Clone)] -pub struct ShaderModuleDescriptorDxil<'a, L> { - /// Entrypoint. - pub entry_point: String, - /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. - pub label: L, - /// Number of workgroups in each dimension x, y and z. - pub num_workgroups: (u32, u32, u32), - /// Shader DXIL source. - pub source: &'a [u8], -} - -/// Descriptor for a shader module given by DirectX HLSL source. -/// -/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, -/// only WGSL source code strings are accepted. -#[derive(Debug, Clone)] -pub struct ShaderModuleDescriptorHlsl<'a, L> { - /// Entrypoint. - pub entry_point: String, - /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. - pub label: L, - /// Number of workgroups in each dimension x, y and z. - pub num_workgroups: (u32, u32, u32), - /// Shader HLSL source. - pub source: &'a str, -} - -/// Descriptor for a shader module given by SPIR-V binary. -/// -/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, -/// only WGSL source code strings are accepted. -#[derive(Debug, Clone)] -pub struct ShaderModuleDescriptorSpirV<'a, L> { - /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. - pub label: L, - /// Binary SPIR-V data, in 4-byte words. - pub source: Cow<'a, [u32]>, -} diff --git a/wgpu/src/api/shader_module.rs b/wgpu/src/api/shader_module.rs index c481de6218a..b2aad03ab02 100644 --- a/wgpu/src/api/shader_module.rs +++ b/wgpu/src/api/shader_module.rs @@ -228,34 +228,10 @@ pub struct ShaderModuleDescriptor<'a> { } static_assertions::assert_impl_all!(ShaderModuleDescriptor<'_>: Send, Sync); -/// Descriptor for a shader module that will bypass wgpu's shader tooling, for use with -/// [`Device::create_shader_module_passthrough`]. +/// Descriptor for a shader module given by any of several sources. +/// At least one of the shader types that may be used by the backend must be `Some` /// /// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, /// only WGSL source code strings are accepted. pub type ShaderModuleDescriptorPassthrough<'a> = wgt::CreateShaderModuleDescriptorPassthrough<'a, Label<'a>>; - -/// Descriptor for a shader module given by Metal MSL source. -/// -/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, -/// only WGSL source code strings are accepted. -pub type ShaderModuleDescriptorMsl<'a> = wgt::ShaderModuleDescriptorMsl<'a, Label<'a>>; - -/// Descriptor for a shader module given by SPIR-V binary. -/// -/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, -/// only WGSL source code strings are accepted. -pub type ShaderModuleDescriptorSpirV<'a> = wgt::ShaderModuleDescriptorSpirV<'a, Label<'a>>; - -/// Descriptor for a shader module given by DirectX HLSL source. -/// -/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, -/// only WGSL source code strings are accepted. -pub type ShaderModuleDescriptorHlsl<'a> = wgt::ShaderModuleDescriptorHlsl<'a, Label<'a>>; - -/// Descriptor for a shader module given by DirectX DXIL source. -/// -/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, -/// only WGSL source code strings are accepted. -pub type ShaderModuleDescriptorDxil<'a> = wgt::ShaderModuleDescriptorDxil<'a, Label<'a>>; diff --git a/wgpu/src/backend/webgpu.rs b/wgpu/src/backend/webgpu.rs index 0bd973cdb63..603922f86a8 100644 --- a/wgpu/src/backend/webgpu.rs +++ b/wgpu/src/backend/webgpu.rs @@ -1862,9 +1862,43 @@ impl dispatch::DeviceInterface for WebDevice { unsafe fn create_shader_module_passthrough( &self, - _desc: &crate::ShaderModuleDescriptorPassthrough<'_>, + desc: &crate::ShaderModuleDescriptorPassthrough<'_>, ) -> dispatch::DispatchShaderModule { - unreachable!("No XXX_SHADER_PASSTHROUGH feature enabled for this backend") + let shader_module_result = if let Some(ref code) = desc.wgsl { + let shader_module = webgpu_sys::GpuShaderModuleDescriptor::new(code); + Ok(( + shader_module, + WebShaderCompilationInfo::Wgsl { + source: code.to_string(), + }, + )) + } else { + Err(crate::CompilationInfo { + messages: vec![crate::CompilationMessage { + message: + "Passthrough shader not compiled for WGSL on WebGPU backend (WGPU error)" + .to_string(), + location: None, + message_type: crate::CompilationMessageType::Error, + }], + }) + }; + let (descriptor, compilation_info) = match shader_module_result { + Ok(v) => v, + Err(compilation_info) => ( + webgpu_sys::GpuShaderModuleDescriptor::new(""), + WebShaderCompilationInfo::Transformed { compilation_info }, + ), + }; + if let Some(label) = desc.label { + descriptor.set_label(label); + } + WebShaderModule { + module: self.inner.create_shader_module(&descriptor), + compilation_info, + ident: crate::cmp::Identifier::create(), + } + .into() } fn create_bind_group_layout( diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 5b96225cdcb..812cd5276a4 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -1088,7 +1088,7 @@ impl dispatch::DeviceInterface for CoreDevice { self.context.handle_error( &self.error_sink, cause.clone(), - desc.label().as_deref(), + desc.label.as_deref(), "Device::create_shader_module_passthrough", ); CompilationInfo::from(cause) diff --git a/wgpu/src/macros.rs b/wgpu/src/macros.rs index c766dfb2178..537756adb92 100644 --- a/wgpu/src/macros.rs +++ b/wgpu/src/macros.rs @@ -96,22 +96,31 @@ macro_rules! include_spirv { }; } -/// Macro to load raw SPIR-V data statically, for use with [`Features::SPIRV_SHADER_PASSTHROUGH`]. +/// Macro to load raw SPIR-V data statically, for use with [`Features::EXPERIMENTAL_PASSTHROUGH_SHADERS`]. /// /// It ensures the word alignment as well as the magic number. /// -/// [`Features::SPIRV_SHADER_PASSTHROUGH`]: crate::Features::SPIRV_SHADER_PASSTHROUGH +/// [`Features::EXPERIMENTAL_PASSTHROUGH_SHADERS`]: crate::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS #[macro_export] macro_rules! include_spirv_raw { ($($token:tt)*) => { { //log::info!("including '{}'", $($token)*); - $crate::ShaderModuleDescriptorPassthrough::SpirV( - $crate::ShaderModuleDescriptorSpirV { - label: $crate::__macro_helpers::Some($($token)*), - source: $crate::util::make_spirv_raw($crate::__macro_helpers::include_bytes!($($token)*)), - } - ) + $crate::ShaderModuleDescriptorPassthrough { + label: $crate::__macro_helpers::Some($($token)*), + spirv: Some($crate::util::make_spirv_raw($crate::__macro_helpers::include_bytes!($($token)*))), + + entry_point: "".to_owned(), + // This is unused for SPIR-V + num_workgroups: (0, 0, 0), + reflection: None, + shader_runtime_checks: $crate::ShaderRuntimeChecks::unchecked(), + dxil: None, + msl: None, + hlsl: None, + glsl: None, + wgsl: None, + } } }; }