From 45fbacc17f24b83b29fab9b3c1e1472d32edd4f6 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 23 Aug 2025 23:21:54 -0500 Subject: [PATCH 01/16] Wait did I break it --- wgpu-hal/src/metal/adapter.rs | 12 +++++++++--- wgpu-hal/src/metal/command.rs | 32 ++++++++++++++++++++++++++------ wgpu-hal/src/metal/device.rs | 28 +++++++++++++--------------- wgpu-hal/src/metal/mod.rs | 21 ++++++++++++++------- 4 files changed, 62 insertions(+), 31 deletions(-) diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index 02dfc0fe601..9517f0b4dd6 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -606,6 +606,8 @@ impl super::PrivateCapabilities { } let argument_buffers = device.argument_buffers_support(); + let mesh_shaders = device.supports_family(MTLGPUFamily::Apple7) + || device.supports_family(MTLGPUFamily::Mac2); Self { family_check, @@ -902,6 +904,7 @@ impl super::PrivateCapabilities { && (device.supports_family(MTLGPUFamily::Apple7) || device.supports_family(MTLGPUFamily::Mac2)), supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac), + mesh_shaders, } } @@ -1003,6 +1006,8 @@ impl super::PrivateCapabilities { features.insert(F::SUBGROUP | F::SUBGROUP_BARRIER); } + features.set(F::EXPERIMENTAL_MESH_SHADER, self.mesh_shaders); + features } @@ -1079,10 +1084,11 @@ impl super::PrivateCapabilities { max_buffer_size: self.max_buffer_size, max_non_sampler_bindings: u32::MAX, - max_task_workgroup_total_count: 0, - max_task_workgroups_per_dimension: 0, + // See https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf, Maximum threadgroups per mesh shader grid + max_task_workgroup_total_count: 1024, + max_task_workgroups_per_dimension: 1024, max_mesh_multiview_count: 0, - max_mesh_output_layers: 0, + max_mesh_output_layers: self.max_texture_layers as u32, max_blas_primitive_count: 0, // When added: 2^28 from https://developer.apple.com/documentation/metal/mtlaccelerationstructureusage/extendedlimits 
max_blas_geometry_count: 0, // When added: 2^24 diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 72a799a0275..2b66343c478 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -906,11 +906,22 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn set_render_pipeline(&mut self, pipeline: &super::RenderPipeline) { self.state.raw_primitive_type = pipeline.raw_primitive_type; - self.state.stage_infos.vs.assign_from(&pipeline.vs_info); + match pipeline.vs_info { + Some(ref info) => self.state.stage_infos.vs.assign_from(info), + None => self.state.stage_infos.vs.clear(), + } match pipeline.fs_info { Some(ref info) => self.state.stage_infos.fs.assign_from(info), None => self.state.stage_infos.fs.clear(), } + match pipeline.ts_info { + Some(ref info) => self.state.stage_infos.ts.assign_from(info), + None => self.state.stage_infos.vs.clear(), + } + match pipeline.ms_info { + Some(ref info) => self.state.stage_infos.ms.assign_from(info), + None => self.state.stage_infos.fs.clear(), + } let encoder = self.state.render.as_ref().unwrap(); encoder.set_render_pipeline_state(&pipeline.raw); @@ -937,7 +948,7 @@ impl crate::CommandEncoder for super::CommandEncoder { ); } } - if pipeline.fs_lib.is_some() { + if pipeline.fs_info.is_some() { if let Some((index, sizes)) = self .state .make_sizes_buffer_update(naga::ShaderStage::Fragment, &mut self.temp.binding_sizes) @@ -1111,11 +1122,20 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn draw_mesh_tasks( &mut self, - _group_count_x: u32, - _group_count_y: u32, - _group_count_z: u32, + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, ) { - unreachable!() + let encoder = self.state.render.as_ref().unwrap(); + encoder.draw_mesh_threadgroups( + MTLSize { + width: group_count_x as u64, + height: group_count_y as u64, + depth: group_count_z as u64, + }, + todo!(), + todo!(), + ); } unsafe fn draw_indirect( diff --git 
a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 6af8ad3062d..97878960a36 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1078,7 +1078,7 @@ impl crate::Device for super::Device { conv::map_primitive_topology(desc.primitive.topology); // Vertex shader - let (vs_lib, vs_info) = { + let vs_info = { let mut vertex_buffer_mappings = Vec::::new(); for (i, vbl) in desc_vertex_buffers.iter().enumerate() { let mut attributes = Vec::::new(); @@ -1124,18 +1124,17 @@ impl crate::Device for super::Device { ); } - let info = super::PipelineStageInfo { + super::PipelineStageInfo { push_constants: desc.layout.push_constants_infos.vs, sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer, sized_bindings: vs.sized_bindings, vertex_buffer_mappings, - }; - - (vs.library, info) + library: Some(vs.library), + } }; // Fragment shader - let (fs_lib, fs_info) = match desc.fragment_stage { + let fs_info = match desc.fragment_stage { Some(ref stage) => { let fs = self.load_shader( stage, @@ -1153,14 +1152,13 @@ impl crate::Device for super::Device { ); } - let info = super::PipelineStageInfo { + Some(super::PipelineStageInfo { push_constants: desc.layout.push_constants_infos.fs, sizes_slot: desc.layout.per_stage_map.fs.sizes_buffer, sized_bindings: fs.sized_bindings, vertex_buffer_mappings: vec![], - }; - - (Some(fs.library), Some(info)) + library: Some(fs.library), + }) } None => { // TODO: This is a workaround for what appears to be a Metal validation bug @@ -1168,7 +1166,7 @@ impl crate::Device for super::Device { if desc.color_targets.is_empty() && desc.depth_stencil.is_none() { descriptor.set_depth_attachment_pixel_format(MTLPixelFormat::Depth32Float); } - (None, None) + None } }; @@ -1302,10 +1300,10 @@ impl crate::Device for super::Device { Ok(super::RenderPipeline { raw, - vs_lib, - fs_lib, - vs_info, + vs_info: Some(vs_info), fs_info, + ts_info: None, + ms_info: None, raw_primitive_type, raw_triangle_fill_mode, raw_front_winding: 
conv::map_winding(desc.primitive.front_face), @@ -1373,6 +1371,7 @@ impl crate::Device for super::Device { } let cs_info = super::PipelineStageInfo { + library: Some(cs.library), push_constants: desc.layout.push_constants_infos.cs, sizes_slot: desc.layout.per_stage_map.cs.sizes_buffer, sized_bindings: cs.sized_bindings, @@ -1400,7 +1399,6 @@ impl crate::Device for super::Device { Ok(super::ComputePipeline { raw, cs_info, - cs_lib: cs.library, work_group_size: cs.wg_size, work_group_memory_sizes: cs.wg_memory_sizes, }) diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index 00223b2f778..ec4ae11cdef 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -300,6 +300,7 @@ struct PrivateCapabilities { int64_atomics: bool, float_atomics: bool, supports_shared_event: bool, + mesh_shaders: bool, } #[derive(Clone, Debug)] @@ -604,12 +605,16 @@ struct MultiStageData { vs: T, fs: T, cs: T, + ts: T, + ms: T, } const NAGA_STAGES: MultiStageData = MultiStageData { vs: naga::ShaderStage::Vertex, fs: naga::ShaderStage::Fragment, cs: naga::ShaderStage::Compute, + ts: naga::ShaderStage::Task, + ms: naga::ShaderStage::Mesh, }; impl ops::Index for MultiStageData { @@ -630,6 +635,8 @@ impl MultiStageData { vs: fun(&self.vs), fs: fun(&self.fs), cs: fun(&self.cs), + ts: fun(&self.ts), + ms: fun(&self.ms), } } fn map(self, fun: impl Fn(T) -> Y) -> MultiStageData { @@ -637,6 +644,8 @@ impl MultiStageData { vs: fun(self.vs), fs: fun(self.fs), cs: fun(self.cs), + ts: fun(self.ts), + ms: fun(self.ms), } } fn iter<'a>(&'a self) -> impl Iterator { @@ -811,6 +820,8 @@ impl crate::DynShaderModule for ShaderModule {} #[derive(Debug, Default)] struct PipelineStageInfo { + #[allow(dead_code)] + library: Option, push_constants: Option, /// The buffer argument table index at which we pass runtime-sized arrays' buffer sizes. 
@@ -849,12 +860,10 @@ impl PipelineStageInfo { #[derive(Debug)] pub struct RenderPipeline { raw: metal::RenderPipelineState, - #[allow(dead_code)] - vs_lib: metal::Library, - #[allow(dead_code)] - fs_lib: Option, - vs_info: PipelineStageInfo, + vs_info: Option, fs_info: Option, + ts_info: Option, + ms_info: Option, raw_primitive_type: MTLPrimitiveType, raw_triangle_fill_mode: MTLTriangleFillMode, raw_front_winding: MTLWinding, @@ -871,8 +880,6 @@ impl crate::DynRenderPipeline for RenderPipeline {} #[derive(Debug)] pub struct ComputePipeline { raw: metal::ComputePipelineState, - #[allow(dead_code)] - cs_lib: metal::Library, cs_info: PipelineStageInfo, work_group_size: MTLSize, work_group_memory_sizes: Vec, From 611c01a4566a3e6bb48dbf599df063db2b2b6449 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 00:11:50 -0500 Subject: [PATCH 02/16] More work --- wgpu-hal/src/metal/command.rs | 59 +++-- wgpu-hal/src/metal/device.rs | 393 +++++++++++++++++++++------------- wgpu-hal/src/metal/mod.rs | 5 +- 3 files changed, 299 insertions(+), 158 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 2b66343c478..37beb41a9a3 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -21,7 +21,6 @@ impl Default for super::CommandState { compute: None, raw_primitive_type: MTLPrimitiveType::Point, index: None, - raw_wg_size: MTLSize::new(0, 0, 0), stage_infos: Default::default(), storage_buffer_length_map: Default::default(), vertex_buffer_size_map: Default::default(), @@ -936,7 +935,7 @@ impl crate::CommandEncoder for super::CommandEncoder { encoder.set_depth_bias(bias.constant as f32, bias.slope_scale, bias.clamp); } - { + if pipeline.vs_info.is_some() { if let Some((index, sizes)) = self .state .make_sizes_buffer_update(naga::ShaderStage::Vertex, &mut self.temp.binding_sizes) @@ -960,6 +959,30 @@ impl crate::CommandEncoder for super::CommandEncoder { ); } } + if pipeline.ts_info.is_some() { + if let 
Some((index, sizes)) = self + .state + .make_sizes_buffer_update(naga::ShaderStage::Task, &mut self.temp.binding_sizes) + { + encoder.set_object_bytes( + index as _, + (sizes.len() * WORD_SIZE) as u64, + sizes.as_ptr().cast(), + ); + } + } + if pipeline.ms_info.is_some() { + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(naga::ShaderStage::Mesh, &mut self.temp.binding_sizes) + { + encoder.set_mesh_bytes( + index as _, + (sizes.len() * WORD_SIZE) as u64, + sizes.as_ptr().cast(), + ); + } + } } unsafe fn set_index_buffer<'a>( @@ -1133,8 +1156,8 @@ impl crate::CommandEncoder for super::CommandEncoder { height: group_count_y as u64, depth: group_count_z as u64, }, - todo!(), - todo!(), + self.state.stage_infos.ts.raw_wg_size, + self.state.stage_infos.ms.raw_wg_size, ); } @@ -1174,11 +1197,20 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn draw_mesh_tasks_indirect( &mut self, - _buffer: &::Buffer, - _offset: wgt::BufferAddress, - _draw_count: u32, + buffer: &::Buffer, + mut offset: wgt::BufferAddress, + draw_count: u32, ) { - unreachable!() + let encoder = self.state.render.as_ref().unwrap(); + for _ in 0..draw_count { + encoder.draw_mesh_threadgroups_with_indirect_buffer( + &buffer.raw, + offset, + self.state.stage_infos.ts.raw_wg_size, + self.state.stage_infos.ms.raw_wg_size, + ); + offset += size_of::() as wgt::BufferAddress; + } } unsafe fn draw_indirect_count( @@ -1210,7 +1242,7 @@ impl crate::CommandEncoder for super::CommandEncoder { _count_offset: wgt::BufferAddress, _max_count: u32, ) { - unreachable!() + //TODO } // compute @@ -1286,7 +1318,6 @@ impl crate::CommandEncoder for super::CommandEncoder { } unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) { - self.state.raw_wg_size = pipeline.work_group_size; self.state.stage_infos.cs.assign_from(&pipeline.cs_info); let encoder = self.state.compute.as_ref().unwrap(); @@ -1330,13 +1361,17 @@ impl crate::CommandEncoder for super::CommandEncoder { 
height: count[1] as u64, depth: count[2] as u64, }; - encoder.dispatch_thread_groups(raw_count, self.state.raw_wg_size); + encoder.dispatch_thread_groups(raw_count, self.state.stage_infos.cs.raw_wg_size); } } unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) { let encoder = self.state.compute.as_ref().unwrap(); - encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.state.raw_wg_size); + encoder.dispatch_thread_groups_indirect( + &buffer.raw, + offset, + self.state.stage_infos.cs.raw_wg_size, + ); } unsafe fn build_acceleration_structures<'a, T>( diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 97878960a36..6474136f4d7 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -18,6 +18,11 @@ use metal::{ type DeviceResult = Result; +enum MetalGenericRenderPipelineDescriptor { + Standard(metal::RenderPipelineDescriptor), + Mesh(metal::MeshRenderPipelineDescriptor), +} + struct CompiledShader { library: metal::Library, function: metal::Function, @@ -1054,83 +1059,207 @@ impl crate::Device for super::Device { super::PipelineCache, >, ) -> Result { - let (desc_vertex_stage, desc_vertex_buffers) = match &desc.vertex_processor { - crate::VertexProcessor::Standard { - vertex_buffers, - vertex_stage, - } => (vertex_stage, *vertex_buffers), - crate::VertexProcessor::Mesh { .. 
} => unreachable!(), - }; - objc::rc::autoreleasepool(|| { - let descriptor = metal::RenderPipelineDescriptor::new(); - - let raw_triangle_fill_mode = match desc.primitive.polygon_mode { - wgt::PolygonMode::Fill => MTLTriangleFillMode::Fill, - wgt::PolygonMode::Line => MTLTriangleFillMode::Lines, - wgt::PolygonMode::Point => panic!( - "{:?} is not enabled for this backend", - wgt::Features::POLYGON_MODE_POINT - ), - }; - let (primitive_class, raw_primitive_type) = conv::map_primitive_topology(desc.primitive.topology); - // Vertex shader - let vs_info = { - let mut vertex_buffer_mappings = Vec::::new(); - for (i, vbl) in desc_vertex_buffers.iter().enumerate() { - let mut attributes = Vec::::new(); - for attribute in vbl.attributes.iter() { - attributes.push(naga::back::msl::AttributeMapping { - shader_location: attribute.shader_location, - offset: attribute.offset as u32, - format: convert_vertex_format_to_naga(attribute.format), - }); - } + let vs_info; + let ts_info; + let ms_info; + let descriptor = match desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers, + ref vertex_stage, + } => { + let descriptor = metal::RenderPipelineDescriptor::new(); + ts_info = None; + ms_info = None; + vs_info = Some({ + let mut vertex_buffer_mappings = + Vec::::new(); + for (i, vbl) in vertex_buffers.iter().enumerate() { + let mut attributes = Vec::::new(); + for attribute in vbl.attributes.iter() { + attributes.push(naga::back::msl::AttributeMapping { + shader_location: attribute.shader_location, + offset: attribute.offset as u32, + format: convert_vertex_format_to_naga(attribute.format), + }); + } + + vertex_buffer_mappings.push(naga::back::msl::VertexBufferMapping { + id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32, + stride: if vbl.array_stride > 0 { + vbl.array_stride.try_into().unwrap() + } else { + vbl.attributes + .iter() + .map(|attribute| attribute.offset + attribute.format.size()) + .max() + .unwrap_or(0) + .try_into() + .unwrap() + 
}, + indexed_by_vertex: (vbl.step_mode + == wgt::VertexStepMode::Vertex {}), + attributes, + }); + } - vertex_buffer_mappings.push(naga::back::msl::VertexBufferMapping { - id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32, - stride: if vbl.array_stride > 0 { - vbl.array_stride.try_into().unwrap() - } else { - vbl.attributes - .iter() - .map(|attribute| attribute.offset + attribute.format.size()) - .max() - .unwrap_or(0) - .try_into() - .unwrap() - }, - indexed_by_vertex: (vbl.step_mode == wgt::VertexStepMode::Vertex {}), - attributes, + let vs = self.load_shader( + vertex_stage, + &vertex_buffer_mappings, + desc.layout, + primitive_class, + naga::ShaderStage::Vertex, + )?; + + descriptor.set_vertex_function(Some(&vs.function)); + if self.shared.private_caps.supports_mutability { + Self::set_buffers_mutability( + descriptor.vertex_buffers().unwrap(), + vs.immutable_buffer_mask, + ); + } + + super::PipelineStageInfo { + push_constants: desc.layout.push_constants_infos.vs, + sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer, + sized_bindings: vs.sized_bindings, + vertex_buffer_mappings, + library: Some(vs.library), + raw_wg_size: Default::default(), + } }); - } + if desc.layout.total_counters.vs.buffers + (vertex_buffers.len() as u32) + > self.shared.private_caps.max_vertex_buffers + { + let msg = format!( + "pipeline needs too many buffers in the vertex stage: {} vertex and {} layout", + vertex_buffers.len(), + desc.layout.total_counters.vs.buffers + ); + return Err(crate::PipelineError::Linkage( + wgt::ShaderStages::VERTEX, + msg, + )); + } - let vs = self.load_shader( - desc_vertex_stage, - &vertex_buffer_mappings, - desc.layout, - primitive_class, - naga::ShaderStage::Vertex, - )?; - - descriptor.set_vertex_function(Some(&vs.function)); - if self.shared.private_caps.supports_mutability { - Self::set_buffers_mutability( - descriptor.vertex_buffers().unwrap(), - vs.immutable_buffer_mask, - ); - } + if !vertex_buffers.is_empty() { + let 
vertex_descriptor = metal::VertexDescriptor::new(); + for (i, vb) in vertex_buffers.iter().enumerate() { + let buffer_index = + self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64; + let buffer_desc = + vertex_descriptor.layouts().object_at(buffer_index).unwrap(); + + // Metal expects the stride to be the actual size of the attributes. + // The semantics of array_stride == 0 can be achieved by setting + // the step function to constant and rate to 0. + if vb.array_stride == 0 { + let stride = vb + .attributes + .iter() + .map(|attribute| attribute.offset + attribute.format.size()) + .max() + .unwrap_or(0); + buffer_desc.set_stride(wgt::math::align_to(stride, 4)); + buffer_desc.set_step_function(MTLVertexStepFunction::Constant); + buffer_desc.set_step_rate(0); + } else { + buffer_desc.set_stride(vb.array_stride); + buffer_desc.set_step_function(conv::map_step_mode(vb.step_mode)); + } - super::PipelineStageInfo { - push_constants: desc.layout.push_constants_infos.vs, - sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer, - sized_bindings: vs.sized_bindings, - vertex_buffer_mappings, - library: Some(vs.library), + for at in vb.attributes { + let attribute_desc = vertex_descriptor + .attributes() + .object_at(at.shader_location as u64) + .unwrap(); + attribute_desc.set_format(conv::map_vertex_format(at.format)); + attribute_desc.set_buffer_index(buffer_index); + attribute_desc.set_offset(at.offset); + } + } + descriptor.set_vertex_descriptor(Some(vertex_descriptor)); + } + todo!() } + crate::VertexProcessor::Mesh { + ref task_stage, + ref mesh_stage, + } => { + vs_info = None; + let descriptor = metal::MeshRenderPipelineDescriptor::new(); + if let Some(ref task_stage) = task_stage { + let ts = self.load_shader( + task_stage, + &[], + desc.layout, + primitive_class, + naga::ShaderStage::Task, + )?; + descriptor.set_mesh_function(Some(&ts.function)); + if self.shared.private_caps.supports_mutability { + Self::set_buffers_mutability( + 
descriptor.mesh_buffers().unwrap(), + ts.immutable_buffer_mask, + ); + } + ts_info = Some(super::PipelineStageInfo { + push_constants: desc.layout.push_constants_infos.ts, + sizes_slot: desc.layout.per_stage_map.ts.sizes_buffer, + sized_bindings: ts.sized_bindings, + vertex_buffer_mappings: vec![], + library: Some(ts.library), + raw_wg_size: Default::default(), + }); + } else { + ts_info = None; + } + { + let ms = self.load_shader( + mesh_stage, + &[], + desc.layout, + primitive_class, + naga::ShaderStage::Mesh, + )?; + descriptor.set_mesh_function(Some(&ms.function)); + if self.shared.private_caps.supports_mutability { + Self::set_buffers_mutability( + descriptor.mesh_buffers().unwrap(), + ms.immutable_buffer_mask, + ); + } + ms_info = Some(super::PipelineStageInfo { + push_constants: desc.layout.push_constants_infos.ms, + sizes_slot: desc.layout.per_stage_map.ms.sizes_buffer, + sized_bindings: ms.sized_bindings, + vertex_buffer_mappings: vec![], + library: Some(ms.library), + raw_wg_size: Default::default(), + }); + } + MetalGenericRenderPipelineDescriptor::Mesh(descriptor) + } + }; + macro_rules! descriptor_fn { + ($method:ident $( ( $($args:expr),* ) )? 
) => { + match descriptor { + MetalGenericRenderPipelineDescriptor::Standard(ref inner) => inner.$method$(($($args),*))?, + MetalGenericRenderPipelineDescriptor::Mesh(ref inner) => inner.$method$(($($args),*))?, + } + }; + } + + let raw_triangle_fill_mode = match desc.primitive.polygon_mode { + wgt::PolygonMode::Fill => MTLTriangleFillMode::Fill, + wgt::PolygonMode::Line => MTLTriangleFillMode::Lines, + wgt::PolygonMode::Point => panic!( + "{:?} is not enabled for this backend", + wgt::Features::POLYGON_MODE_POINT + ), }; // Fragment shader @@ -1144,10 +1273,10 @@ impl crate::Device for super::Device { naga::ShaderStage::Fragment, )?; - descriptor.set_fragment_function(Some(&fs.function)); + descriptor_fn!(set_fragment_function(Some(&fs.function))); if self.shared.private_caps.supports_mutability { Self::set_buffers_mutability( - descriptor.fragment_buffers().unwrap(), + descriptor_fn!(fragment_buffers()).unwrap(), fs.immutable_buffer_mask, ); } @@ -1158,20 +1287,25 @@ impl crate::Device for super::Device { sized_bindings: fs.sized_bindings, vertex_buffer_mappings: vec![], library: Some(fs.library), + raw_wg_size: Default::default(), }) } None => { // TODO: This is a workaround for what appears to be a Metal validation bug // A pixel format is required even though no attachments are provided if desc.color_targets.is_empty() && desc.depth_stencil.is_none() { - descriptor.set_depth_attachment_pixel_format(MTLPixelFormat::Depth32Float); + descriptor_fn!(set_depth_attachment_pixel_format( + MTLPixelFormat::Depth32Float + )); } None } }; for (i, ct) in desc.color_targets.iter().enumerate() { - let at_descriptor = descriptor.color_attachments().object_at(i as u64).unwrap(); + let at_descriptor = descriptor_fn!(color_attachments()) + .object_at(i as u64) + .unwrap(); let ct = if let Some(color_target) = ct.as_ref() { color_target } else { @@ -1203,10 +1337,10 @@ impl crate::Device for super::Device { let raw_format = self.shared.private_caps.map_format(ds.format); let 
aspects = crate::FormatAspects::from(ds.format); if aspects.contains(crate::FormatAspects::DEPTH) { - descriptor.set_depth_attachment_pixel_format(raw_format); + descriptor_fn!(set_depth_attachment_pixel_format(raw_format)); } if aspects.contains(crate::FormatAspects::STENCIL) { - descriptor.set_stencil_attachment_pixel_format(raw_format); + descriptor_fn!(set_stencil_attachment_pixel_format(raw_format)); } let ds_descriptor = create_depth_stencil_desc(ds); @@ -1220,90 +1354,61 @@ impl crate::Device for super::Device { None => None, }; - if desc.layout.total_counters.vs.buffers + (desc_vertex_buffers.len() as u32) - > self.shared.private_caps.max_vertex_buffers - { - let msg = format!( - "pipeline needs too many buffers in the vertex stage: {} vertex and {} layout", - desc_vertex_buffers.len(), - desc.layout.total_counters.vs.buffers - ); - return Err(crate::PipelineError::Linkage( - wgt::ShaderStages::VERTEX, - msg, - )); - } - - if !desc_vertex_buffers.is_empty() { - let vertex_descriptor = metal::VertexDescriptor::new(); - for (i, vb) in desc_vertex_buffers.iter().enumerate() { - let buffer_index = - self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64; - let buffer_desc = vertex_descriptor.layouts().object_at(buffer_index).unwrap(); - - // Metal expects the stride to be the actual size of the attributes. - // The semantics of array_stride == 0 can be achieved by setting - // the step function to constant and rate to 0. 
- if vb.array_stride == 0 { - let stride = vb - .attributes - .iter() - .map(|attribute| attribute.offset + attribute.format.size()) - .max() - .unwrap_or(0); - buffer_desc.set_stride(wgt::math::align_to(stride, 4)); - buffer_desc.set_step_function(MTLVertexStepFunction::Constant); - buffer_desc.set_step_rate(0); - } else { - buffer_desc.set_stride(vb.array_stride); - buffer_desc.set_step_function(conv::map_step_mode(vb.step_mode)); + if desc.multisample.count != 1 { + //TODO: handle sample mask + match descriptor { + MetalGenericRenderPipelineDescriptor::Standard(ref inner) => { + inner.set_sample_count(desc.multisample.count as u64); } - - for at in vb.attributes { - let attribute_desc = vertex_descriptor - .attributes() - .object_at(at.shader_location as u64) - .unwrap(); - attribute_desc.set_format(conv::map_vertex_format(at.format)); - attribute_desc.set_buffer_index(buffer_index); - attribute_desc.set_offset(at.offset); + MetalGenericRenderPipelineDescriptor::Mesh(ref inner) => { + inner.set_raster_sample_count(desc.multisample.count as u64); } } - descriptor.set_vertex_descriptor(Some(vertex_descriptor)); - } - - if desc.multisample.count != 1 { - //TODO: handle sample mask - descriptor.set_sample_count(desc.multisample.count as u64); - descriptor - .set_alpha_to_coverage_enabled(desc.multisample.alpha_to_coverage_enabled); + descriptor_fn!(set_alpha_to_coverage_enabled( + desc.multisample.alpha_to_coverage_enabled + )); //descriptor.set_alpha_to_one_enabled(desc.multisample.alpha_to_one_enabled); } if let Some(name) = desc.label { - descriptor.set_label(name); + descriptor_fn!(set_label(name)); } - let raw = self - .shared - .device - .lock() - .new_render_pipeline_state(&descriptor) - .map_err(|e| { - crate::PipelineError::Linkage( - wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT, - format!("new_render_pipeline_state: {e:?}"), - ) - })?; + let raw = match descriptor { + MetalGenericRenderPipelineDescriptor::Standard(d) => self + .shared + .device 
+ .lock() + .new_render_pipeline_state(&d) + .map_err(|e| { + crate::PipelineError::Linkage( + wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT, + format!("new_render_pipeline_state: {e:?}"), + ) + })?, + MetalGenericRenderPipelineDescriptor::Mesh(d) => self + .shared + .device + .lock() + .new_mesh_render_pipeline_state(&d) + .map_err(|e| { + crate::PipelineError::Linkage( + wgt::ShaderStages::TASK + | wgt::ShaderStages::MESH + | wgt::ShaderStages::FRAGMENT, + format!("new_render_pipeline_state: {e:?}"), + ) + })?, + }; self.counters.render_pipelines.add(1); Ok(super::RenderPipeline { raw, - vs_info: Some(vs_info), + vs_info, fs_info, - ts_info: None, - ms_info: None, + ts_info, + ms_info, raw_primitive_type, raw_triangle_fill_mode, raw_front_winding: conv::map_winding(desc.primitive.front_face), @@ -1376,6 +1481,7 @@ impl crate::Device for super::Device { sizes_slot: desc.layout.per_stage_map.cs.sizes_buffer, sized_bindings: cs.sized_bindings, vertex_buffer_mappings: vec![], + raw_wg_size: cs.wg_size, }; if let Some(name) = desc.label { @@ -1399,7 +1505,6 @@ impl crate::Device for super::Device { Ok(super::ComputePipeline { raw, cs_info, - work_group_size: cs.wg_size, work_group_memory_sizes: cs.wg_memory_sizes, }) }) diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index ec4ae11cdef..a9d9e19b57b 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -836,6 +836,9 @@ struct PipelineStageInfo { /// Info on all bound vertex buffers. 
vertex_buffer_mappings: Vec, + + /// The workgroup size for compute, task or mesh stages + raw_wg_size: MTLSize, } impl PipelineStageInfo { @@ -881,7 +884,6 @@ impl crate::DynRenderPipeline for RenderPipeline {} pub struct ComputePipeline { raw: metal::ComputePipelineState, cs_info: PipelineStageInfo, - work_group_size: MTLSize, work_group_memory_sizes: Vec, } @@ -956,7 +958,6 @@ struct CommandState { compute: Option, raw_primitive_type: MTLPrimitiveType, index: Option, - raw_wg_size: MTLSize, stage_infos: MultiStageData, /// Sizes of currently bound [`wgt::BufferBindingType::Storage`] buffers. From 3d36680bca124a3d61a7e426553c71a6bdb4eab6 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 00:13:03 -0500 Subject: [PATCH 03/16] Oops --- wgpu-hal/src/metal/device.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 6474136f4d7..4f1154c42c3 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1183,7 +1183,7 @@ impl crate::Device for super::Device { } descriptor.set_vertex_descriptor(Some(vertex_descriptor)); } - todo!() + MetalGenericRenderPipelineDescriptor::Standard(descriptor) } crate::VertexProcessor::Mesh { ref task_stage, From c9c39fd4ab74d7d0516c3e556fb86e9935dde031 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 00:22:57 -0500 Subject: [PATCH 04/16] Another refactor --- wgpu-hal/src/metal/adapter.rs | 5 ++--- wgpu-hal/src/metal/command.rs | 6 ++++-- wgpu-hal/src/metal/device.rs | 11 ++++++----- wgpu-hal/src/metal/mod.rs | 4 +++- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index 9517f0b4dd6..d298ee7da15 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -606,8 +606,6 @@ impl super::PrivateCapabilities { } let argument_buffers = device.argument_buffers_support(); - let mesh_shaders = 
device.supports_family(MTLGPUFamily::Apple7) - || device.supports_family(MTLGPUFamily::Mac2); Self { family_check, @@ -904,7 +902,8 @@ impl super::PrivateCapabilities { && (device.supports_family(MTLGPUFamily::Apple7) || device.supports_family(MTLGPUFamily::Mac2)), supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac), - mesh_shaders, + mesh_shaders: device.supports_family(MTLGPUFamily::Apple7) + || device.supports_family(MTLGPUFamily::Mac2), } } diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 37beb41a9a3..db282a8d91e 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -1335,14 +1335,16 @@ impl crate::CommandEncoder for super::CommandEncoder { } // update the threadgroup memory sizes - while self.state.work_group_memory_sizes.len() < pipeline.work_group_memory_sizes.len() { + while self.state.work_group_memory_sizes.len() + < pipeline.cs_info.work_group_memory_sizes.len() + { self.state.work_group_memory_sizes.push(0); } for (index, (cur_size, pipeline_size)) in self .state .work_group_memory_sizes .iter_mut() - .zip(pipeline.work_group_memory_sizes.iter()) + .zip(pipeline.cs_info.work_group_memory_sizes.iter()) .enumerate() { let size = pipeline_size.next_multiple_of(16); diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 4f1154c42c3..ee1a74b2131 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1129,6 +1129,7 @@ impl crate::Device for super::Device { vertex_buffer_mappings, library: Some(vs.library), raw_wg_size: Default::default(), + work_group_memory_sizes: vec![], } }); if desc.layout.total_counters.vs.buffers + (vertex_buffers.len() as u32) @@ -1213,6 +1214,7 @@ impl crate::Device for super::Device { vertex_buffer_mappings: vec![], library: Some(ts.library), raw_wg_size: Default::default(), + work_group_memory_sizes: vec![], }); } else { ts_info = None; @@ -1239,6 +1241,7 @@ impl crate::Device for super::Device { 
vertex_buffer_mappings: vec![], library: Some(ms.library), raw_wg_size: Default::default(), + work_group_memory_sizes: vec![], }); } MetalGenericRenderPipelineDescriptor::Mesh(descriptor) @@ -1288,6 +1291,7 @@ impl crate::Device for super::Device { vertex_buffer_mappings: vec![], library: Some(fs.library), raw_wg_size: Default::default(), + work_group_memory_sizes: vec![], }) } None => { @@ -1482,6 +1486,7 @@ impl crate::Device for super::Device { sized_bindings: cs.sized_bindings, vertex_buffer_mappings: vec![], raw_wg_size: cs.wg_size, + work_group_memory_sizes: cs.wg_memory_sizes, }; if let Some(name) = desc.label { @@ -1502,11 +1507,7 @@ impl crate::Device for super::Device { self.counters.compute_pipelines.add(1); - Ok(super::ComputePipeline { - raw, - cs_info, - work_group_memory_sizes: cs.wg_memory_sizes, - }) + Ok(super::ComputePipeline { raw, cs_info }) }) } diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index a9d9e19b57b..c2d2a80a214 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -839,6 +839,9 @@ struct PipelineStageInfo { /// The workgroup size for compute, task or mesh stages raw_wg_size: MTLSize, + + /// The workgroup memory sizes for compute task or mesh stages + work_group_memory_sizes: Vec, } impl PipelineStageInfo { @@ -884,7 +887,6 @@ impl crate::DynRenderPipeline for RenderPipeline {} pub struct ComputePipeline { raw: metal::ComputePipelineState, cs_info: PipelineStageInfo, - work_group_memory_sizes: Vec, } unsafe impl Send for ComputePipeline {} From fb330288f734898ac0e6a000ba32e4f3d23e3b4b Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 01:49:43 -0500 Subject: [PATCH 05/16] Another slight refactor --- wgpu-hal/src/metal/command.rs | 8 ++++---- wgpu-hal/src/metal/mod.rs | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index db282a8d91e..a91035b642f 100644 --- a/wgpu-hal/src/metal/command.rs +++ 
b/wgpu-hal/src/metal/command.rs @@ -24,7 +24,6 @@ impl Default for super::CommandState { stage_infos: Default::default(), storage_buffer_length_map: Default::default(), vertex_buffer_size_map: Default::default(), - work_group_memory_sizes: Vec::new(), push_constants: Vec::new(), pending_timer_queries: Vec::new(), } @@ -149,7 +148,6 @@ impl super::CommandState { self.stage_infos.vs.clear(); self.stage_infos.fs.clear(); self.stage_infos.cs.clear(); - self.work_group_memory_sizes.clear(); self.push_constants.clear(); } @@ -1335,13 +1333,15 @@ impl crate::CommandEncoder for super::CommandEncoder { } // update the threadgroup memory sizes - while self.state.work_group_memory_sizes.len() + while self.state.stage_infos.cs.work_group_memory_sizes.len() < pipeline.cs_info.work_group_memory_sizes.len() { - self.state.work_group_memory_sizes.push(0); + self.state.stage_infos.cs.work_group_memory_sizes.push(0); } for (index, (cur_size, pipeline_size)) in self .state + .stage_infos + .cs .work_group_memory_sizes .iter_mut() .zip(pipeline.cs_info.work_group_memory_sizes.iter()) diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index c2d2a80a214..c4d9992e7db 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -850,6 +850,9 @@ impl PipelineStageInfo { self.sizes_slot = None; self.sized_bindings.clear(); self.vertex_buffer_mappings.clear(); + self.library = None; + self.work_group_memory_sizes.clear(); + self.raw_wg_size = Default::default(); } fn assign_from(&mut self, other: &Self) { @@ -985,7 +988,6 @@ struct CommandState { vertex_buffer_size_map: FastHashMap, - work_group_memory_sizes: Vec, push_constants: Vec, /// Timer query that should be executed when the next pass starts. 
From ece1ea10c1c9c3cebbc2db4ec1742e5aac1eb289 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 01:58:22 -0500 Subject: [PATCH 06/16] Another slight refactor --- wgpu-hal/src/metal/command.rs | 2 ++ wgpu-hal/src/metal/mod.rs | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index a91035b642f..a83540a9a37 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -148,6 +148,8 @@ impl super::CommandState { self.stage_infos.vs.clear(); self.stage_infos.fs.clear(); self.stage_infos.cs.clear(); + self.stage_infos.ts.clear(); + self.stage_infos.ms.clear(); self.push_constants.clear(); } diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index c4d9992e7db..1e7b5281240 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -863,6 +863,11 @@ impl PipelineStageInfo { self.vertex_buffer_mappings.clear(); self.vertex_buffer_mappings .extend_from_slice(&other.vertex_buffer_mappings); + self.library = Some(other.library.as_ref().unwrap().clone()); + self.raw_wg_size = other.raw_wg_size; + self.work_group_memory_sizes.clear(); + self.work_group_memory_sizes + .extend_from_slice(&other.work_group_memory_sizes); } } From 47c187b40ed14a08112ad0221dff9295c6190259 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 01:59:49 -0500 Subject: [PATCH 07/16] Fixed it --- wgpu-hal/src/metal/command.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index a83540a9a37..542287983e9 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -915,11 +915,11 @@ impl crate::CommandEncoder for super::CommandEncoder { } match pipeline.ts_info { Some(ref info) => self.state.stage_infos.ts.assign_from(info), - None => self.state.stage_infos.vs.clear(), + None => self.state.stage_infos.ts.clear(), } match pipeline.ms_info { Some(ref info) => 
self.state.stage_infos.ms.assign_from(info), - None => self.state.stage_infos.fs.clear(), + None => self.state.stage_infos.ms.clear(), } let encoder = self.state.render.as_ref().unwrap(); From 8bc63b662a542971d6b8f3c19284e35fc3417592 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 02:33:22 -0500 Subject: [PATCH 08/16] Worked a little more on trying to add it to example --- examples/features/src/mesh_shader/mod.rs | 32 ++++++-- .../features/src/mesh_shader/shader.metal | 74 +++++++++++++++++++ wgpu-types/src/lib.rs | 4 +- 3 files changed, 102 insertions(+), 8 deletions(-) create mode 100644 examples/features/src/mesh_shader/shader.metal diff --git a/examples/features/src/mesh_shader/mod.rs b/examples/features/src/mesh_shader/mod.rs index 675150f5106..e21e7ae2c95 100644 --- a/examples/features/src/mesh_shader/mod.rs +++ b/examples/features/src/mesh_shader/mod.rs @@ -33,13 +33,25 @@ fn compile_glsl( } } +fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule { + unsafe { + device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough { + entry_point: entry.to_owned(), + label: None, + msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))), + num_workgroups: (1, 1, 1), + ..Default::default() + }) + } +} + pub struct Example { pipeline: wgpu::RenderPipeline, } impl crate::framework::Example for Example { fn init( config: &wgpu::SurfaceConfiguration, - _adapter: &wgpu::Adapter, + adapter: &wgpu::Adapter, device: &wgpu::Device, _queue: &wgpu::Queue, ) -> Self { @@ -48,11 +60,19 @@ impl crate::framework::Example for Example { bind_group_layouts: &[], push_constant_ranges: &[], }); - let (ts, ms, fs) = ( - compile_glsl(device, include_bytes!("shader.task"), "task"), - compile_glsl(device, include_bytes!("shader.mesh"), "mesh"), - compile_glsl(device, include_bytes!("shader.frag"), "frag"), - ); + let (ts, ms, fs) = if adapter.get_info().backend == wgpu::Backend::Metal { + ( + compile_msl(device, 
"taskShader"), + compile_msl(device, "meshShader"), + compile_msl(device, "fragShader"), + ) + } else { + ( + compile_glsl(device, include_bytes!("shader.task"), "task"), + compile_glsl(device, include_bytes!("shader.mesh"), "mesh"), + compile_glsl(device, include_bytes!("shader.frag"), "frag"), + ) + }; let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor { label: None, layout: Some(&pipeline_layout), diff --git a/examples/features/src/mesh_shader/shader.metal b/examples/features/src/mesh_shader/shader.metal new file mode 100644 index 00000000000..0a563132a19 --- /dev/null +++ b/examples/features/src/mesh_shader/shader.metal @@ -0,0 +1,74 @@ +using namespace metal; + +struct OutVertex { + float4 Position [[position]]; + float4 Color; +}; + +struct OutPrimitive { + float4 ColorMask [[flat]]; + bool CullPrimitive; +}; + +struct InVertex { + float4 Color; +}; + +struct InPrimitive { + float4 ColorMask [[flat]]; +}; + +struct PayloadData { + float4 ColorMask; + bool Visible; +}; + +using Meshlet = metal::mesh; + + +constant float4 positions[3] = { + float4(0.0, 1.0, 0.0, 1.0), + float4(-1.0, -1.0, 0.0, 1.0), + float4(1.0, -1.0, 0.0, 1.0) +}; + +constant float4 colors[3] = { + float4(0.0, 1.0, 0.0, 1.0), + float4(0.0, 0.0, 1.0, 1.0), + float4(1.0, 0.0, 0.0, 1.0) +}; + + +[[object]] +void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], grid_properties grid) { + outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0); + outPayload.Visible = true; + grid.set_threadgroups_per_grid(uint3(3, 1, 1)); +} + +[[mesh, topology(triangle)]] +void meshShader( + object_data PayloadData const& payload [[payload]], + Meshlet out, +) +{ + out.set_primitive_count(1); + + for(int i = 0;i < 3;i++) { + OutVertex vert; + vert.Position = positions[i]; + vert.Color = colors[i] * payload.ColorMask; + mesh.set_vertex(i, vert); + out.set_index(i, i); + } + + triangles[0] = uint3(0, 1, 2); + OutPrimitive prim; + prim.ColorMask = 
float4(1.0, 0.0, 0.0, 1.0); + prim.CullPrimitive = !payload.Visible; + out.set_primitive(0, prim); +} + +fragment float4 fragShader(OutVertex inVertex [[stage_in]], OutPrimitive inPrimitive [[stage_in]]) { + return inVertex.Color * inPrimitive.ColorMask; +} diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index ea2a09eb62a..828136a690c 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -979,8 +979,8 @@ impl Limits { // Literally just made this up as 256^2 or 2^16. // My GPU supports 2^22, and compute shaders don't have this kind of limit. // This very likely is never a real limiter - max_task_workgroup_total_count: 65536, - max_task_workgroups_per_dimension: 256, + max_task_workgroup_total_count: 1024, + max_task_workgroups_per_dimension: 1024, // llvmpipe reports 0 multiview count, which just means no multiview is allowed max_mesh_multiview_count: 0, // llvmpipe once again requires this to be 8. An RTX 3060 supports well over 1024. From 55d6bf3b3ab94c3ea16d09856bc393d98a9b67ed Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 02:40:34 -0500 Subject: [PATCH 09/16] Fixed metal shader --- examples/features/src/mesh_shader/shader.metal | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/features/src/mesh_shader/shader.metal b/examples/features/src/mesh_shader/shader.metal index 0a563132a19..65edc83e442 100644 --- a/examples/features/src/mesh_shader/shader.metal +++ b/examples/features/src/mesh_shader/shader.metal @@ -18,6 +18,11 @@ struct InPrimitive { float4 ColorMask [[flat]]; }; +struct FragmentIn { + InVertex vert; + InPrimitive prim; +}; + struct PayloadData { float4 ColorMask; bool Visible; @@ -40,16 +45,16 @@ constant float4 colors[3] = { [[object]] -void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], grid_properties grid) { +void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload 
[[payload]], mesh_grid_properties grid) { outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0); outPayload.Visible = true; grid.set_threadgroups_per_grid(uint3(3, 1, 1)); } -[[mesh, topology(triangle)]] +[[mesh]] void meshShader( object_data PayloadData const& payload [[payload]], - Meshlet out, + Meshlet out ) { out.set_primitive_count(1); @@ -58,17 +63,16 @@ void meshShader( OutVertex vert; vert.Position = positions[i]; vert.Color = colors[i] * payload.ColorMask; - mesh.set_vertex(i, vert); + out.set_vertex(i, vert); out.set_index(i, i); } - triangles[0] = uint3(0, 1, 2); OutPrimitive prim; prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0); prim.CullPrimitive = !payload.Visible; out.set_primitive(0, prim); } -fragment float4 fragShader(OutVertex inVertex [[stage_in]], OutPrimitive inPrimitive [[stage_in]]) { - return inVertex.Color * inPrimitive.ColorMask; +fragment float4 fragShader(FragmentIn data [[stage_in]]) { + return data.vert.Color * data.prim.ColorMask; } From edfd494cbd71ffa8b9b2dd82583e91a9e5c68a18 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 02:56:17 -0500 Subject: [PATCH 10/16] Fixed some passthrough stuff, now it runs (uggh) --- wgpu-hal/src/metal/device.rs | 326 ++++++++++++++++++----------------- wgpu-hal/src/metal/mod.rs | 3 +- 2 files changed, 174 insertions(+), 155 deletions(-) diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index ee1a74b2131..3a48c9e8ead 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -133,176 +133,194 @@ impl super::Device { primitive_class: MTLPrimitiveTopologyClass, naga_stage: naga::ShaderStage, ) -> Result { - let naga_shader = if let ShaderModuleSource::Naga(naga) = &stage.module.source { - naga - } else { - panic!("load_shader required a naga shader"); - }; - let stage_bit = map_naga_stage(naga_stage); - let (module, module_info) = naga::back::pipeline_constants::process_overrides( - &naga_shader.module, - &naga_shader.info, - Some((naga_stage, 
stage.entry_point)), - stage.constants, - ) - .map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {e:?}")))?; - - let ep_resources = &layout.per_stage_map[naga_stage]; - - let bounds_check_policy = if stage.module.bounds_checks.bounds_checks { - naga::proc::BoundsCheckPolicy::Restrict - } else { - naga::proc::BoundsCheckPolicy::Unchecked - }; + match stage.module.source { + ShaderModuleSource::Naga(ref naga_shader) => { + let stage_bit = map_naga_stage(naga_stage); + let (module, module_info) = naga::back::pipeline_constants::process_overrides( + &naga_shader.module, + &naga_shader.info, + Some((naga_stage, stage.entry_point)), + stage.constants, + ) + .map_err(|e| { + crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {e:?}")) + })?; - let options = naga::back::msl::Options { - lang_version: match self.shared.private_caps.msl_version { - MTLLanguageVersion::V1_0 => (1, 0), - MTLLanguageVersion::V1_1 => (1, 1), - MTLLanguageVersion::V1_2 => (1, 2), - MTLLanguageVersion::V2_0 => (2, 0), - MTLLanguageVersion::V2_1 => (2, 1), - MTLLanguageVersion::V2_2 => (2, 2), - MTLLanguageVersion::V2_3 => (2, 3), - MTLLanguageVersion::V2_4 => (2, 4), - MTLLanguageVersion::V3_0 => (3, 0), - MTLLanguageVersion::V3_1 => (3, 1), - }, - inline_samplers: Default::default(), - spirv_cross_compatibility: false, - fake_missing_bindings: false, - per_entry_point_map: naga::back::msl::EntryPointResourceMap::from([( - stage.entry_point.to_owned(), - ep_resources.clone(), - )]), - bounds_check_policies: naga::proc::BoundsCheckPolicies { - index: bounds_check_policy, - buffer: bounds_check_policy, - image_load: bounds_check_policy, - // TODO: support bounds checks on binding arrays - binding_array: naga::proc::BoundsCheckPolicy::Unchecked, - }, - zero_initialize_workgroup_memory: stage.zero_initialize_workgroup_memory, - force_loop_bounding: stage.module.bounds_checks.force_loop_bounding, - }; + let ep_resources = &layout.per_stage_map[naga_stage]; - let 
pipeline_options = naga::back::msl::PipelineOptions { - entry_point: Some((naga_stage, stage.entry_point.to_owned())), - allow_and_force_point_size: match primitive_class { - MTLPrimitiveTopologyClass::Point => true, - _ => false, - }, - vertex_pulling_transform: true, - vertex_buffer_mappings: vertex_buffer_mappings.to_vec(), - }; + let bounds_check_policy = if stage.module.bounds_checks.bounds_checks { + naga::proc::BoundsCheckPolicy::Restrict + } else { + naga::proc::BoundsCheckPolicy::Unchecked + }; - let (source, info) = - naga::back::msl::write_string(&module, &module_info, &options, &pipeline_options) - .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {e:?}")))?; + let options = naga::back::msl::Options { + lang_version: match self.shared.private_caps.msl_version { + MTLLanguageVersion::V1_0 => (1, 0), + MTLLanguageVersion::V1_1 => (1, 1), + MTLLanguageVersion::V1_2 => (1, 2), + MTLLanguageVersion::V2_0 => (2, 0), + MTLLanguageVersion::V2_1 => (2, 1), + MTLLanguageVersion::V2_2 => (2, 2), + MTLLanguageVersion::V2_3 => (2, 3), + MTLLanguageVersion::V2_4 => (2, 4), + MTLLanguageVersion::V3_0 => (3, 0), + MTLLanguageVersion::V3_1 => (3, 1), + }, + inline_samplers: Default::default(), + spirv_cross_compatibility: false, + fake_missing_bindings: false, + per_entry_point_map: naga::back::msl::EntryPointResourceMap::from([( + stage.entry_point.to_owned(), + ep_resources.clone(), + )]), + bounds_check_policies: naga::proc::BoundsCheckPolicies { + index: bounds_check_policy, + buffer: bounds_check_policy, + image_load: bounds_check_policy, + // TODO: support bounds checks on binding arrays + binding_array: naga::proc::BoundsCheckPolicy::Unchecked, + }, + zero_initialize_workgroup_memory: stage.zero_initialize_workgroup_memory, + force_loop_bounding: stage.module.bounds_checks.force_loop_bounding, + }; - log::debug!( - "Naga generated shader for entry point '{}' and stage {:?}\n{}", - stage.entry_point, - naga_stage, - &source - ); + let 
pipeline_options = naga::back::msl::PipelineOptions { + entry_point: Some((naga_stage, stage.entry_point.to_owned())), + allow_and_force_point_size: match primitive_class { + MTLPrimitiveTopologyClass::Point => true, + _ => false, + }, + vertex_pulling_transform: true, + vertex_buffer_mappings: vertex_buffer_mappings.to_vec(), + }; - let options = metal::CompileOptions::new(); - options.set_language_version(self.shared.private_caps.msl_version); + let (source, info) = naga::back::msl::write_string( + &module, + &module_info, + &options, + &pipeline_options, + ) + .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {e:?}")))?; - if self.shared.private_caps.supports_preserve_invariance { - options.set_preserve_invariance(true); - } + log::debug!( + "Naga generated shader for entry point '{}' and stage {:?}\n{}", + stage.entry_point, + naga_stage, + &source + ); - let library = self - .shared - .device - .lock() - .new_library_with_source(source.as_ref(), &options) - .map_err(|err| { - log::warn!("Naga generated shader:\n{source}"); - crate::PipelineError::Linkage(stage_bit, format!("Metal: {err}")) - })?; - - let ep_index = module - .entry_points - .iter() - .position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point) - .ok_or(crate::PipelineError::EntryPoint(naga_stage))?; - let ep = &module.entry_points[ep_index]; - let translated_ep_name = info.entry_point_names[0] - .as_ref() - .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?; - - let wg_size = MTLSize { - width: ep.workgroup_size[0] as _, - height: ep.workgroup_size[1] as _, - depth: ep.workgroup_size[2] as _, - }; + let options = metal::CompileOptions::new(); + options.set_language_version(self.shared.private_caps.msl_version); - let function = library - .get_function(translated_ep_name, None) - .map_err(|e| { - log::error!("get_function: {e:?}"); - crate::PipelineError::EntryPoint(naga_stage) - })?; - - // collect sizes indices, immutable buffers, and work group 
memory sizes - let ep_info = &module_info.get_entry_point(ep_index); - let mut wg_memory_sizes = Vec::new(); - let mut sized_bindings = Vec::new(); - let mut immutable_buffer_mask = 0; - for (var_handle, var) in module.global_variables.iter() { - match var.space { - naga::AddressSpace::WorkGroup => { - if !ep_info[var_handle].is_empty() { - let size = module.types[var.ty].inner.size(module.to_ctx()); - wg_memory_sizes.push(size); - } + if self.shared.private_caps.supports_preserve_invariance { + options.set_preserve_invariance(true); } - naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. } => { - let br = match var.binding { - Some(br) => br, - None => continue, - }; - let storage_access_store = match var.space { - naga::AddressSpace::Storage { access } => { - access.contains(naga::StorageAccess::STORE) + + let library = self + .shared + .device + .lock() + .new_library_with_source(source.as_ref(), &options) + .map_err(|err| { + log::warn!("Naga generated shader:\n{source}"); + crate::PipelineError::Linkage(stage_bit, format!("Metal: {err}")) + })?; + + let ep_index = module + .entry_points + .iter() + .position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point) + .ok_or(crate::PipelineError::EntryPoint(naga_stage))?; + let ep = &module.entry_points[ep_index]; + let translated_ep_name = info.entry_point_names[0] + .as_ref() + .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?; + + let wg_size = MTLSize { + width: ep.workgroup_size[0] as _, + height: ep.workgroup_size[1] as _, + depth: ep.workgroup_size[2] as _, + }; + + let function = library + .get_function(translated_ep_name, None) + .map_err(|e| { + log::error!("get_function: {e:?}"); + crate::PipelineError::EntryPoint(naga_stage) + })?; + + // collect sizes indices, immutable buffers, and work group memory sizes + let ep_info = &module_info.get_entry_point(ep_index); + let mut wg_memory_sizes = Vec::new(); + let mut sized_bindings = Vec::new(); + let mut 
immutable_buffer_mask = 0; + for (var_handle, var) in module.global_variables.iter() { + match var.space { + naga::AddressSpace::WorkGroup => { + if !ep_info[var_handle].is_empty() { + let size = module.types[var.ty].inner.size(module.to_ctx()); + wg_memory_sizes.push(size); + } } - _ => false, - }; + naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. } => { + let br = match var.binding { + Some(br) => br, + None => continue, + }; + let storage_access_store = match var.space { + naga::AddressSpace::Storage { access } => { + access.contains(naga::StorageAccess::STORE) + } + _ => false, + }; - // check for an immutable buffer - if !ep_info[var_handle].is_empty() && !storage_access_store { - let slot = ep_resources.resources[&br].buffer.unwrap(); - immutable_buffer_mask |= 1 << slot; - } + // check for an immutable buffer + if !ep_info[var_handle].is_empty() && !storage_access_store { + let slot = ep_resources.resources[&br].buffer.unwrap(); + immutable_buffer_mask |= 1 << slot; + } - let mut dynamic_array_container_ty = var.ty; - if let naga::TypeInner::Struct { ref members, .. } = module.types[var.ty].inner - { - dynamic_array_container_ty = members.last().unwrap().ty; - } - if let naga::TypeInner::Array { - size: naga::ArraySize::Dynamic, - .. - } = module.types[dynamic_array_container_ty].inner - { - sized_bindings.push(br); + let mut dynamic_array_container_ty = var.ty; + if let naga::TypeInner::Struct { ref members, .. } = + module.types[var.ty].inner + { + dynamic_array_container_ty = members.last().unwrap().ty; + } + if let naga::TypeInner::Array { + size: naga::ArraySize::Dynamic, + .. 
+ } = module.types[dynamic_array_container_ty].inner + { + sized_bindings.push(br); + } + } + _ => {} } } - _ => {} + + Ok(CompiledShader { + library, + function, + wg_size, + wg_memory_sizes, + sized_bindings, + immutable_buffer_mask, + }) } + ShaderModuleSource::Passthrough(ref shader) => Ok(CompiledShader { + library: shader.library.clone(), + function: shader.function.clone(), + wg_size: MTLSize { + width: shader.num_workgroups.0 as u64, + height: shader.num_workgroups.1 as u64, + depth: shader.num_workgroups.2 as u64, + }, + wg_memory_sizes: vec![], + sized_bindings: vec![], + immutable_buffer_mask: 0, + }), } - - Ok(CompiledShader { - library, - function, - wg_size, - wg_memory_sizes, - sized_bindings, - immutable_buffer_mask, - }) } fn set_buffers_mutability( diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index 1e7b5281240..fda7e001906 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -624,7 +624,8 @@ impl ops::Index for MultiStageData { naga::ShaderStage::Vertex => &self.vs, naga::ShaderStage::Fragment => &self.fs, naga::ShaderStage::Compute => &self.cs, - naga::ShaderStage::Task | naga::ShaderStage::Mesh => unreachable!(), + naga::ShaderStage::Task => &self.ts, + naga::ShaderStage::Mesh => &self.ms, } } } From d4725b1f14cd6f8c6ba00068884e0dea22a71343 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 13:54:15 -0500 Subject: [PATCH 11/16] Small update to test shader (still blank screen) --- examples/features/src/mesh_shader/shader.metal | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/features/src/mesh_shader/shader.metal b/examples/features/src/mesh_shader/shader.metal index 65edc83e442..5c99fffc231 100644 --- a/examples/features/src/mesh_shader/shader.metal +++ b/examples/features/src/mesh_shader/shader.metal @@ -2,20 +2,20 @@ using namespace metal; struct OutVertex { float4 Position [[position]]; - float4 Color; + float4 Color [[user(locn0)]]; }; struct 
OutPrimitive { - float4 ColorMask [[flat]]; - bool CullPrimitive; + float4 ColorMask [[flat]] [[user(locn1)]]; + bool CullPrimitive [[primitive_culled]]; }; struct InVertex { - float4 Color; + float4 Color [[user(locn0)]]; }; struct InPrimitive { - float4 ColorMask [[flat]]; + float4 ColorMask [[flat]] [[user(locn1)]]; }; struct FragmentIn { From 7efae60bd9138c37f095c946fd4b349b0f454eb6 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 13:56:20 -0500 Subject: [PATCH 12/16] Another quick update to the shader --- examples/features/src/mesh_shader/shader.metal | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/features/src/mesh_shader/shader.metal b/examples/features/src/mesh_shader/shader.metal index 5c99fffc231..4c7da503832 100644 --- a/examples/features/src/mesh_shader/shader.metal +++ b/examples/features/src/mesh_shader/shader.metal @@ -11,7 +11,6 @@ struct OutPrimitive { }; struct InVertex { - float4 Color [[user(locn0)]]; }; struct InPrimitive { @@ -19,8 +18,8 @@ struct InPrimitive { }; struct FragmentIn { - InVertex vert; - InPrimitive prim; + float4 Color [[user(locn0)]]; + float4 ColorMask [[flat]] [[user(locn1)]]; }; struct PayloadData { @@ -74,5 +73,5 @@ void meshShader( } fragment float4 fragShader(FragmentIn data [[stage_in]]) { - return data.vert.Color * data.prim.ColorMask; + return data.Color * data.ColorMask; } From bd79d513438e30f1fe845490ddda835f6f9f46ef Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 13:59:38 -0500 Subject: [PATCH 13/16] Made mesh shader tests get skipped on metal due to not having MSL passthrough yet --- tests/tests/wgpu-gpu/mesh_shader/mod.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/tests/wgpu-gpu/mesh_shader/mod.rs b/tests/tests/wgpu-gpu/mesh_shader/mod.rs index 4dd897129f6..ae705c92341 100644 --- a/tests/tests/wgpu-gpu/mesh_shader/mod.rs +++ b/tests/tests/wgpu-gpu/mesh_shader/mod.rs @@ -86,6 +86,9 @@ fn mesh_pipeline_build( frag: 
Option<&[u8]>, draw: bool, ) { + if ctx.adapter.get_info().backend != wgpu::Backend::Vulkan { + return; + } let device = &ctx.device; let (_depth_image, depth_view, depth_state) = create_depth(device); let task = task.map(|t| compile_glsl(device, t, "task")); @@ -160,6 +163,9 @@ pub enum DrawType { } fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) { + if ctx.adapter.get_info().backend != wgpu::Backend::Vulkan { + return; + } let device = &ctx.device; let (_depth_image, depth_view, depth_state) = create_depth(device); let task = compile_glsl(device, BASIC_TASK, "task"); From 760de4b59e007cbde0d0d80e761d7ea839b7b476 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 14:33:19 -0500 Subject: [PATCH 14/16] Add changelog entry --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 11fb072dcab..2f679924d06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -110,6 +110,9 @@ This allows using precompiled shaders without manually checking which backend's - Allow disabling waiting for latency waitable object. By @marcpabst in [#7400](https://github.com/gfx-rs/wgpu/pull/7400) +#### Metal +- Add support for mesh shaders. 
By @SupaMaggie70Incorporated in [#8139](https://github.com/gfx-rs/wgpu/pull/8139) + ### Bug Fixes #### General From 3f56df6b4842a14a79301ec8eeeeac309deac01b Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 15:10:47 -0500 Subject: [PATCH 15/16] Made some stuff more generic (bind groups & push constants) --- wgpu-hal/src/metal/command.rs | 300 +++++++++++++++++----------------- 1 file changed, 148 insertions(+), 152 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 542287983e9..1e4ac8d2419 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -672,168 +672,150 @@ impl crate::CommandEncoder for super::CommandEncoder { dynamic_offsets: &[wgt::DynamicOffset], ) { let bg_info = &layout.bind_group_infos[group_index as usize]; - - if let Some(ref encoder) = self.state.render { - let mut changes_sizes_buffer = false; - for index in 0..group.counters.vs.buffers { - let buf = &group.buffers[index as usize]; - let mut offset = buf.offset; - if let Some(dyn_index) = buf.dynamic_index { - offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; + let render_encoder = self.state.render.clone(); + let compute_encoder = self.state.compute.clone(); + let mut update_stage = + |stage: naga::ShaderStage, + render_encoder: Option<&metal::RenderCommandEncoder>, + compute_encoder: Option<&metal::ComputeCommandEncoder>| { + let buffers = match stage { + naga::ShaderStage::Vertex => group.counters.vs.buffers, + naga::ShaderStage::Fragment => group.counters.fs.buffers, + naga::ShaderStage::Task => group.counters.ts.buffers, + naga::ShaderStage::Mesh => group.counters.ms.buffers, + naga::ShaderStage::Compute => group.counters.cs.buffers, + }; + let mut changes_sizes_buffer = false; + for index in 0..buffers { + let buf = &group.buffers[index as usize]; + let mut offset = buf.offset; + if let Some(dyn_index) = buf.dynamic_index { + offset += dynamic_offsets[dyn_index as usize] as 
wgt::BufferAddress; + } + let a1 = (bg_info.base_resource_indices.vs.buffers + index) as u64; + let a2 = Some(buf.ptr.as_native()); + let a3 = offset; + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_buffer(a1, a2, a3) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_buffer(a1, a2, a3) + } + naga::ShaderStage::Task => { + render_encoder.unwrap().set_object_buffer(a1, a2, a3) + } + naga::ShaderStage::Mesh => { + render_encoder.unwrap().set_mesh_buffer(a1, a2, a3) + } + naga::ShaderStage::Compute => { + compute_encoder.unwrap().set_buffer(a1, a2, a3) + } + } + if let Some(size) = buf.binding_size { + let br = naga::ResourceBinding { + group: group_index, + binding: buf.binding_location, + }; + self.state.storage_buffer_length_map.insert(br, size); + changes_sizes_buffer = true; + } } - encoder.set_vertex_buffer( - (bg_info.base_resource_indices.vs.buffers + index) as u64, - Some(buf.ptr.as_native()), - offset, - ); - if let Some(size) = buf.binding_size { - let br = naga::ResourceBinding { - group: group_index, - binding: buf.binding_location, - }; - self.state.storage_buffer_length_map.insert(br, size); - changes_sizes_buffer = true; + if changes_sizes_buffer { + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(stage, &mut self.temp.binding_sizes) + { + let a1 = index as _; + let a2 = (sizes.len() * WORD_SIZE) as u64; + let a3 = sizes.as_ptr().cast(); + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_bytes(a1, a2, a3) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_bytes(a1, a2, a3) + } + naga::ShaderStage::Task => { + render_encoder.unwrap().set_object_bytes(a1, a2, a3) + } + naga::ShaderStage::Mesh => { + render_encoder.unwrap().set_mesh_bytes(a1, a2, a3) + } + naga::ShaderStage::Compute => { + compute_encoder.unwrap().set_bytes(a1, a2, a3) + } + } + } } - } - if changes_sizes_buffer { - if let Some((index, 
sizes)) = self.state.make_sizes_buffer_update( - naga::ShaderStage::Vertex, - &mut self.temp.binding_sizes, - ) { - encoder.set_vertex_bytes( - index as _, - (sizes.len() * WORD_SIZE) as u64, - sizes.as_ptr().cast(), - ); + let samplers = match stage { + naga::ShaderStage::Vertex => group.counters.vs.samplers, + naga::ShaderStage::Fragment => group.counters.fs.samplers, + naga::ShaderStage::Task => group.counters.ts.samplers, + naga::ShaderStage::Mesh => group.counters.ms.samplers, + naga::ShaderStage::Compute => group.counters.cs.samplers, + }; + for index in 0..samplers { + let res = group.samplers[(group.counters.vs.samplers + index) as usize]; + let a1 = (bg_info.base_resource_indices.fs.samplers + index) as u64; + let a2 = Some(res.as_native()); + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_sampler_state(a1, a2) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_sampler_state(a1, a2) + } + naga::ShaderStage::Task => { + render_encoder.unwrap().set_object_sampler_state(a1, a2) + } + naga::ShaderStage::Mesh => { + render_encoder.unwrap().set_mesh_sampler_state(a1, a2) + } + naga::ShaderStage::Compute => { + compute_encoder.unwrap().set_sampler_state(a1, a2) + } + } } - } - changes_sizes_buffer = false; - for index in 0..group.counters.fs.buffers { - let buf = &group.buffers[(group.counters.vs.buffers + index) as usize]; - let mut offset = buf.offset; - if let Some(dyn_index) = buf.dynamic_index { - offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; - } - encoder.set_fragment_buffer( - (bg_info.base_resource_indices.fs.buffers + index) as u64, - Some(buf.ptr.as_native()), - offset, - ); - if let Some(size) = buf.binding_size { - let br = naga::ResourceBinding { - group: group_index, - binding: buf.binding_location, - }; - self.state.storage_buffer_length_map.insert(br, size); - changes_sizes_buffer = true; - } - } - if changes_sizes_buffer { - if let Some((index, sizes)) = 
self.state.make_sizes_buffer_update( - naga::ShaderStage::Fragment, - &mut self.temp.binding_sizes, - ) { - encoder.set_fragment_bytes( - index as _, - (sizes.len() * WORD_SIZE) as u64, - sizes.as_ptr().cast(), - ); + let textures = match stage { + naga::ShaderStage::Vertex => group.counters.vs.textures, + naga::ShaderStage::Fragment => group.counters.fs.textures, + naga::ShaderStage::Task => group.counters.ts.textures, + naga::ShaderStage::Mesh => group.counters.ms.textures, + naga::ShaderStage::Compute => group.counters.cs.textures, + }; + for index in 0..textures { + let res = group.textures[index as usize]; + let a1 = (bg_info.base_resource_indices.vs.textures + index) as u64; + let a2 = Some(res.as_native()); + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_texture(a1, a2) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_texture(a1, a2) + } + naga::ShaderStage::Task => { + render_encoder.unwrap().set_object_texture(a1, a2) + } + naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2), + naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2), + } } - } - - for index in 0..group.counters.vs.samplers { - let res = group.samplers[index as usize]; - encoder.set_vertex_sampler_state( - (bg_info.base_resource_indices.vs.samplers + index) as u64, - Some(res.as_native()), - ); - } - for index in 0..group.counters.fs.samplers { - let res = group.samplers[(group.counters.vs.samplers + index) as usize]; - encoder.set_fragment_sampler_state( - (bg_info.base_resource_indices.fs.samplers + index) as u64, - Some(res.as_native()), - ); - } - - for index in 0..group.counters.vs.textures { - let res = group.textures[index as usize]; - encoder.set_vertex_texture( - (bg_info.base_resource_indices.vs.textures + index) as u64, - Some(res.as_native()), - ); - } - for index in 0..group.counters.fs.textures { - let res = group.textures[(group.counters.vs.textures + index) as 
usize]; - encoder.set_fragment_texture( - (bg_info.base_resource_indices.fs.textures + index) as u64, - Some(res.as_native()), - ); - } - + }; + if let Some(encoder) = render_encoder { + update_stage(naga::ShaderStage::Vertex, Some(&encoder), None); + update_stage(naga::ShaderStage::Fragment, Some(&encoder), None); + update_stage(naga::ShaderStage::Task, Some(&encoder), None); + update_stage(naga::ShaderStage::Mesh, Some(&encoder), None); // Call useResource on all textures and buffers used indirectly so they are alive for (resource, use_info) in group.resources_to_use.iter() { encoder.use_resource_at(resource.as_native(), use_info.uses, use_info.stages); } } - - if let Some(ref encoder) = self.state.compute { - let index_base = super::ResourceData { - buffers: group.counters.vs.buffers + group.counters.fs.buffers, - samplers: group.counters.vs.samplers + group.counters.fs.samplers, - textures: group.counters.vs.textures + group.counters.fs.textures, - }; - - let mut changes_sizes_buffer = false; - for index in 0..group.counters.cs.buffers { - let buf = &group.buffers[(index_base.buffers + index) as usize]; - let mut offset = buf.offset; - if let Some(dyn_index) = buf.dynamic_index { - offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; - } - encoder.set_buffer( - (bg_info.base_resource_indices.cs.buffers + index) as u64, - Some(buf.ptr.as_native()), - offset, - ); - if let Some(size) = buf.binding_size { - let br = naga::ResourceBinding { - group: group_index, - binding: buf.binding_location, - }; - self.state.storage_buffer_length_map.insert(br, size); - changes_sizes_buffer = true; - } - } - if changes_sizes_buffer { - if let Some((index, sizes)) = self.state.make_sizes_buffer_update( - naga::ShaderStage::Compute, - &mut self.temp.binding_sizes, - ) { - encoder.set_bytes( - index as _, - (sizes.len() * WORD_SIZE) as u64, - sizes.as_ptr().cast(), - ); - } - } - - for index in 0..group.counters.cs.samplers { - let res = 
group.samplers[(index_base.samplers + index) as usize]; - encoder.set_sampler_state( - (bg_info.base_resource_indices.cs.samplers + index) as u64, - Some(res.as_native()), - ); - } - for index in 0..group.counters.cs.textures { - let res = group.textures[(index_base.textures + index) as usize]; - encoder.set_texture( - (bg_info.base_resource_indices.cs.textures + index) as u64, - Some(res.as_native()), - ); - } - + if let Some(encoder) = compute_encoder { + update_stage(naga::ShaderStage::Compute, None, Some(&encoder)); // Call useResource on all textures and buffers used indirectly so they are alive for (resource, use_info) in group.resources_to_use.iter() { if !use_info.visible_in_compute { @@ -881,6 +863,20 @@ impl crate::CommandEncoder for super::CommandEncoder { state_pc.as_ptr().cast(), ) } + if stages.contains(wgt::ShaderStages::TASK) { + self.state.render.as_ref().unwrap().set_object_bytes( + layout.push_constants_infos.ts.unwrap().buffer_index as _, + (layout.total_push_constants as usize * WORD_SIZE) as _, + state_pc.as_ptr().cast(), + ) + } + if stages.contains(wgt::ShaderStages::MESH) { + self.state.render.as_ref().unwrap().set_mesh_bytes( + layout.push_constants_infos.ms.unwrap().buffer_index as _, + (layout.total_push_constants as usize * WORD_SIZE) as _, + state_pc.as_ptr().cast(), + ) + } } unsafe fn insert_debug_marker(&mut self, label: &str) { From d6931d2f0a2217b8a8ba28f23e15cae0b89f54f2 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 15:27:35 -0500 Subject: [PATCH 16/16] Applied some fixes --- wgpu-hal/src/metal/command.rs | 11 ++++++----- wgpu-hal/src/metal/device.rs | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 1e4ac8d2419..2ebf80f0f26 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -1146,12 +1146,13 @@ impl crate::CommandEncoder for super::CommandEncoder { group_count_z: u32, ) { let encoder = 
self.state.render.as_ref().unwrap(); + let size = MTLSize { + width: group_count_x as u64, + height: group_count_y as u64, + depth: group_count_z as u64, + }; encoder.draw_mesh_threadgroups( - MTLSize { - width: group_count_x as u64, - height: group_count_y as u64, - depth: group_count_z as u64, - }, + size, self.state.stage_infos.ts.raw_wg_size, self.state.stage_infos.ms.raw_wg_size, ); diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index ca16a222efb..70753a3ff6c 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1264,7 +1264,7 @@ impl crate::Device for super::Device { sized_bindings: ts.sized_bindings, vertex_buffer_mappings: vec![], library: Some(ts.library), - raw_wg_size: Default::default(), + raw_wg_size: ts.wg_size, work_group_memory_sizes: vec![], }); } else { @@ -1291,7 +1291,7 @@ impl crate::Device for super::Device { sized_bindings: ms.sized_bindings, vertex_buffer_mappings: vec![], library: Some(ms.library), - raw_wg_size: Default::default(), + raw_wg_size: ms.wg_size, work_group_memory_sizes: vec![], }); }