diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index d034be31f3..dd5d05b6d5 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -128,176 +128,194 @@ impl super::Device { primitive_class: MTLPrimitiveTopologyClass, naga_stage: naga::ShaderStage, ) -> Result { - let naga_shader = if let ShaderModuleSource::Naga(naga) = &stage.module.source { - naga - } else { - panic!("load_shader required a naga shader"); - }; - let stage_bit = map_naga_stage(naga_stage); - let (module, module_info) = naga::back::pipeline_constants::process_overrides( - &naga_shader.module, - &naga_shader.info, - Some((naga_stage, stage.entry_point)), - stage.constants, - ) - .map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {e:?}")))?; - - let ep_resources = &layout.per_stage_map[naga_stage]; - - let bounds_check_policy = if stage.module.bounds_checks.bounds_checks { - naga::proc::BoundsCheckPolicy::Restrict - } else { - naga::proc::BoundsCheckPolicy::Unchecked - }; + match stage.module.source { + ShaderModuleSource::Naga(ref naga_shader) => { + let stage_bit = map_naga_stage(naga_stage); + let (module, module_info) = naga::back::pipeline_constants::process_overrides( + &naga_shader.module, + &naga_shader.info, + Some((naga_stage, stage.entry_point)), + stage.constants, + ) + .map_err(|e| { + crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {e:?}")) + })?; - let options = naga::back::msl::Options { - lang_version: match self.shared.private_caps.msl_version { - MTLLanguageVersion::V1_0 => (1, 0), - MTLLanguageVersion::V1_1 => (1, 1), - MTLLanguageVersion::V1_2 => (1, 2), - MTLLanguageVersion::V2_0 => (2, 0), - MTLLanguageVersion::V2_1 => (2, 1), - MTLLanguageVersion::V2_2 => (2, 2), - MTLLanguageVersion::V2_3 => (2, 3), - MTLLanguageVersion::V2_4 => (2, 4), - MTLLanguageVersion::V3_0 => (3, 0), - MTLLanguageVersion::V3_1 => (3, 1), - }, - inline_samplers: Default::default(), - spirv_cross_compatibility: false, - fake_missing_bindings: false, - per_entry_point_map: naga::back::msl::EntryPointResourceMap::from([( - stage.entry_point.to_owned(), - ep_resources.clone(), - )]), - bounds_check_policies: naga::proc::BoundsCheckPolicies { - index: bounds_check_policy, - buffer: bounds_check_policy, - image_load: bounds_check_policy, - // TODO: support bounds checks on binding arrays - binding_array: naga::proc::BoundsCheckPolicy::Unchecked, - }, - zero_initialize_workgroup_memory: stage.zero_initialize_workgroup_memory, - force_loop_bounding: stage.module.bounds_checks.force_loop_bounding, - }; + let ep_resources = &layout.per_stage_map[naga_stage]; - let pipeline_options = naga::back::msl::PipelineOptions { - entry_point: Some((naga_stage, stage.entry_point.to_owned())), - allow_and_force_point_size: match primitive_class { - MTLPrimitiveTopologyClass::Point => true, - _ => false, - }, - vertex_pulling_transform: true, - vertex_buffer_mappings: vertex_buffer_mappings.to_vec(), - }; + let bounds_check_policy = if stage.module.bounds_checks.bounds_checks { + naga::proc::BoundsCheckPolicy::Restrict + } else { + naga::proc::BoundsCheckPolicy::Unchecked + }; - let (source, info) = - naga::back::msl::write_string(&module, &module_info, &options, &pipeline_options) - .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {e:?}")))?; + let options = naga::back::msl::Options { + lang_version: match self.shared.private_caps.msl_version { + MTLLanguageVersion::V1_0 => (1, 0), + MTLLanguageVersion::V1_1 => (1, 1), + MTLLanguageVersion::V1_2 => (1, 2), + MTLLanguageVersion::V2_0 => (2, 0), + MTLLanguageVersion::V2_1 => (2, 1), + MTLLanguageVersion::V2_2 => (2, 2), + MTLLanguageVersion::V2_3 => (2, 3), + MTLLanguageVersion::V2_4 => (2, 4), + MTLLanguageVersion::V3_0 => (3, 0), + MTLLanguageVersion::V3_1 => (3, 1), + }, + inline_samplers: Default::default(), + spirv_cross_compatibility: false, + fake_missing_bindings: false, + per_entry_point_map: naga::back::msl::EntryPointResourceMap::from([( + stage.entry_point.to_owned(), + ep_resources.clone(), + )]), + bounds_check_policies: naga::proc::BoundsCheckPolicies { + index: bounds_check_policy, + buffer: bounds_check_policy, + image_load: bounds_check_policy, + // TODO: support bounds checks on binding arrays + binding_array: naga::proc::BoundsCheckPolicy::Unchecked, + }, + zero_initialize_workgroup_memory: stage.zero_initialize_workgroup_memory, + force_loop_bounding: stage.module.bounds_checks.force_loop_bounding, + }; - log::debug!( - "Naga generated shader for entry point '{}' and stage {:?}\n{}", - stage.entry_point, - naga_stage, - &source - ); + let pipeline_options = naga::back::msl::PipelineOptions { + entry_point: Some((naga_stage, stage.entry_point.to_owned())), + allow_and_force_point_size: match primitive_class { + MTLPrimitiveTopologyClass::Point => true, + _ => false, + }, + vertex_pulling_transform: true, + vertex_buffer_mappings: vertex_buffer_mappings.to_vec(), + }; - let options = metal::CompileOptions::new(); - options.set_language_version(self.shared.private_caps.msl_version); + let (source, info) = naga::back::msl::write_string( + &module, + &module_info, + &options, + &pipeline_options, + ) + .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {e:?}")))?; - if self.shared.private_caps.supports_preserve_invariance { - options.set_preserve_invariance(true); - } + log::debug!( + "Naga generated shader for entry point '{}' and stage {:?}\n{}", + stage.entry_point, + naga_stage, + &source + ); - let library = self - .shared - .device - .lock() - .new_library_with_source(source.as_ref(), &options) - .map_err(|err| { - log::warn!("Naga generated shader:\n{source}"); - crate::PipelineError::Linkage(stage_bit, format!("Metal: {err}")) - })?; - - let ep_index = module - .entry_points - .iter() - .position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point) - .ok_or(crate::PipelineError::EntryPoint(naga_stage))?; - let ep = &module.entry_points[ep_index]; - let translated_ep_name = info.entry_point_names[0] - .as_ref() - .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?; - - let wg_size = MTLSize { - width: ep.workgroup_size[0] as _, - height: ep.workgroup_size[1] as _, - depth: ep.workgroup_size[2] as _, - }; + let options = metal::CompileOptions::new(); + options.set_language_version(self.shared.private_caps.msl_version); - let function = library - .get_function(translated_ep_name, None) - .map_err(|e| { - log::error!("get_function: {e:?}"); - crate::PipelineError::EntryPoint(naga_stage) - })?; - - // collect sizes indices, immutable buffers, and work group memory sizes - let ep_info = &module_info.get_entry_point(ep_index); - let mut wg_memory_sizes = Vec::new(); - let mut sized_bindings = Vec::new(); - let mut immutable_buffer_mask = 0; - for (var_handle, var) in module.global_variables.iter() { - match var.space { - naga::AddressSpace::WorkGroup => { - if !ep_info[var_handle].is_empty() { - let size = module.types[var.ty].inner.size(module.to_ctx()); - wg_memory_sizes.push(size); - } + if self.shared.private_caps.supports_preserve_invariance { + options.set_preserve_invariance(true); } - naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. } => { - let br = match var.binding { - Some(br) => br, - None => continue, - }; - let storage_access_store = match var.space { - naga::AddressSpace::Storage { access } => { - access.contains(naga::StorageAccess::STORE) + + let library = self + .shared + .device + .lock() + .new_library_with_source(source.as_ref(), &options) + .map_err(|err| { + log::warn!("Naga generated shader:\n{source}"); + crate::PipelineError::Linkage(stage_bit, format!("Metal: {err}")) + })?; + + let ep_index = module + .entry_points + .iter() + .position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point) + .ok_or(crate::PipelineError::EntryPoint(naga_stage))?; + let ep = &module.entry_points[ep_index]; + let translated_ep_name = info.entry_point_names[0] + .as_ref() + .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?; + + let wg_size = MTLSize { + width: ep.workgroup_size[0] as _, + height: ep.workgroup_size[1] as _, + depth: ep.workgroup_size[2] as _, + }; + + let function = library + .get_function(translated_ep_name, None) + .map_err(|e| { + log::error!("get_function: {e:?}"); + crate::PipelineError::EntryPoint(naga_stage) + })?; + + // collect sizes indices, immutable buffers, and work group memory sizes + let ep_info = &module_info.get_entry_point(ep_index); + let mut wg_memory_sizes = Vec::new(); + let mut sized_bindings = Vec::new(); + let mut immutable_buffer_mask = 0; + for (var_handle, var) in module.global_variables.iter() { + match var.space { + naga::AddressSpace::WorkGroup => { + if !ep_info[var_handle].is_empty() { + let size = module.types[var.ty].inner.size(module.to_ctx()); + wg_memory_sizes.push(size); + } } - _ => false, - }; + naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. } => { + let br = match var.binding { + Some(br) => br, + None => continue, + }; + let storage_access_store = match var.space { + naga::AddressSpace::Storage { access } => { + access.contains(naga::StorageAccess::STORE) + } + _ => false, + }; - // check for an immutable buffer - if !ep_info[var_handle].is_empty() && !storage_access_store { - let slot = ep_resources.resources[&br].buffer.unwrap(); - immutable_buffer_mask |= 1 << slot; - } + // check for an immutable buffer + if !ep_info[var_handle].is_empty() && !storage_access_store { + let slot = ep_resources.resources[&br].buffer.unwrap(); + immutable_buffer_mask |= 1 << slot; + } - let mut dynamic_array_container_ty = var.ty; - if let naga::TypeInner::Struct { ref members, .. } = module.types[var.ty].inner - { - dynamic_array_container_ty = members.last().unwrap().ty; - } - if let naga::TypeInner::Array { - size: naga::ArraySize::Dynamic, - .. - } = module.types[dynamic_array_container_ty].inner - { - sized_bindings.push(br); + let mut dynamic_array_container_ty = var.ty; + if let naga::TypeInner::Struct { ref members, .. } = + module.types[var.ty].inner + { + dynamic_array_container_ty = members.last().unwrap().ty; + } + if let naga::TypeInner::Array { + size: naga::ArraySize::Dynamic, + .. + } = module.types[dynamic_array_container_ty].inner + { + sized_bindings.push(br); + } + } + _ => {} } } - _ => {} + + Ok(CompiledShader { + library, + function, + wg_size, + wg_memory_sizes, + sized_bindings, + immutable_buffer_mask, + }) } + ShaderModuleSource::Passthrough(ref shader) => Ok(CompiledShader { + library: shader.library.clone(), + function: shader.function.clone(), + wg_size: MTLSize { + width: shader.num_workgroups.0 as u64, + height: shader.num_workgroups.1 as u64, + depth: shader.num_workgroups.2 as u64, + }, + wg_memory_sizes: vec![], + sized_bindings: vec![], + immutable_buffer_mask: 0, + }), } - - Ok(CompiledShader { - library, - function, - wg_size, - wg_memory_sizes, - sized_bindings, - immutable_buffer_mask, - }) } fn set_buffers_mutability(