Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 172 additions & 154 deletions wgpu-hal/src/metal/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,176 +128,194 @@ impl super::Device {
primitive_class: MTLPrimitiveTopologyClass,
naga_stage: naga::ShaderStage,
) -> Result<CompiledShader, crate::PipelineError> {
let naga_shader = if let ShaderModuleSource::Naga(naga) = &stage.module.source {
naga
} else {
panic!("load_shader required a naga shader");
};
let stage_bit = map_naga_stage(naga_stage);
let (module, module_info) = naga::back::pipeline_constants::process_overrides(
&naga_shader.module,
&naga_shader.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
.map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {e:?}")))?;

let ep_resources = &layout.per_stage_map[naga_stage];

let bounds_check_policy = if stage.module.bounds_checks.bounds_checks {
naga::proc::BoundsCheckPolicy::Restrict
} else {
naga::proc::BoundsCheckPolicy::Unchecked
};
match stage.module.source {
ShaderModuleSource::Naga(ref naga_shader) => {
let stage_bit = map_naga_stage(naga_stage);
let (module, module_info) = naga::back::pipeline_constants::process_overrides(
&naga_shader.module,
&naga_shader.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
.map_err(|e| {
crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {e:?}"))
})?;

let options = naga::back::msl::Options {
lang_version: match self.shared.private_caps.msl_version {
MTLLanguageVersion::V1_0 => (1, 0),
MTLLanguageVersion::V1_1 => (1, 1),
MTLLanguageVersion::V1_2 => (1, 2),
MTLLanguageVersion::V2_0 => (2, 0),
MTLLanguageVersion::V2_1 => (2, 1),
MTLLanguageVersion::V2_2 => (2, 2),
MTLLanguageVersion::V2_3 => (2, 3),
MTLLanguageVersion::V2_4 => (2, 4),
MTLLanguageVersion::V3_0 => (3, 0),
MTLLanguageVersion::V3_1 => (3, 1),
},
inline_samplers: Default::default(),
spirv_cross_compatibility: false,
fake_missing_bindings: false,
per_entry_point_map: naga::back::msl::EntryPointResourceMap::from([(
stage.entry_point.to_owned(),
ep_resources.clone(),
)]),
bounds_check_policies: naga::proc::BoundsCheckPolicies {
index: bounds_check_policy,
buffer: bounds_check_policy,
image_load: bounds_check_policy,
// TODO: support bounds checks on binding arrays
binding_array: naga::proc::BoundsCheckPolicy::Unchecked,
},
zero_initialize_workgroup_memory: stage.zero_initialize_workgroup_memory,
force_loop_bounding: stage.module.bounds_checks.force_loop_bounding,
};
let ep_resources = &layout.per_stage_map[naga_stage];

let pipeline_options = naga::back::msl::PipelineOptions {
entry_point: Some((naga_stage, stage.entry_point.to_owned())),
allow_and_force_point_size: match primitive_class {
MTLPrimitiveTopologyClass::Point => true,
_ => false,
},
vertex_pulling_transform: true,
vertex_buffer_mappings: vertex_buffer_mappings.to_vec(),
};
let bounds_check_policy = if stage.module.bounds_checks.bounds_checks {
naga::proc::BoundsCheckPolicy::Restrict
} else {
naga::proc::BoundsCheckPolicy::Unchecked
};

let (source, info) =
naga::back::msl::write_string(&module, &module_info, &options, &pipeline_options)
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {e:?}")))?;
let options = naga::back::msl::Options {
lang_version: match self.shared.private_caps.msl_version {
MTLLanguageVersion::V1_0 => (1, 0),
MTLLanguageVersion::V1_1 => (1, 1),
MTLLanguageVersion::V1_2 => (1, 2),
MTLLanguageVersion::V2_0 => (2, 0),
MTLLanguageVersion::V2_1 => (2, 1),
MTLLanguageVersion::V2_2 => (2, 2),
MTLLanguageVersion::V2_3 => (2, 3),
MTLLanguageVersion::V2_4 => (2, 4),
MTLLanguageVersion::V3_0 => (3, 0),
MTLLanguageVersion::V3_1 => (3, 1),
},
inline_samplers: Default::default(),
spirv_cross_compatibility: false,
fake_missing_bindings: false,
per_entry_point_map: naga::back::msl::EntryPointResourceMap::from([(
stage.entry_point.to_owned(),
ep_resources.clone(),
)]),
bounds_check_policies: naga::proc::BoundsCheckPolicies {
index: bounds_check_policy,
buffer: bounds_check_policy,
image_load: bounds_check_policy,
// TODO: support bounds checks on binding arrays
binding_array: naga::proc::BoundsCheckPolicy::Unchecked,
},
zero_initialize_workgroup_memory: stage.zero_initialize_workgroup_memory,
force_loop_bounding: stage.module.bounds_checks.force_loop_bounding,
};

log::debug!(
"Naga generated shader for entry point '{}' and stage {:?}\n{}",
stage.entry_point,
naga_stage,
&source
);
let pipeline_options = naga::back::msl::PipelineOptions {
entry_point: Some((naga_stage, stage.entry_point.to_owned())),
allow_and_force_point_size: match primitive_class {
MTLPrimitiveTopologyClass::Point => true,
_ => false,
},
vertex_pulling_transform: true,
vertex_buffer_mappings: vertex_buffer_mappings.to_vec(),
};

let options = metal::CompileOptions::new();
options.set_language_version(self.shared.private_caps.msl_version);
let (source, info) = naga::back::msl::write_string(
&module,
&module_info,
&options,
&pipeline_options,
)
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {e:?}")))?;

if self.shared.private_caps.supports_preserve_invariance {
options.set_preserve_invariance(true);
}
log::debug!(
"Naga generated shader for entry point '{}' and stage {:?}\n{}",
stage.entry_point,
naga_stage,
&source
);

let library = self
.shared
.device
.lock()
.new_library_with_source(source.as_ref(), &options)
.map_err(|err| {
log::warn!("Naga generated shader:\n{source}");
crate::PipelineError::Linkage(stage_bit, format!("Metal: {err}"))
})?;

let ep_index = module
.entry_points
.iter()
.position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point)
.ok_or(crate::PipelineError::EntryPoint(naga_stage))?;
let ep = &module.entry_points[ep_index];
let translated_ep_name = info.entry_point_names[0]
.as_ref()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;

let wg_size = MTLSize {
width: ep.workgroup_size[0] as _,
height: ep.workgroup_size[1] as _,
depth: ep.workgroup_size[2] as _,
};
let options = metal::CompileOptions::new();
options.set_language_version(self.shared.private_caps.msl_version);

let function = library
.get_function(translated_ep_name, None)
.map_err(|e| {
log::error!("get_function: {e:?}");
crate::PipelineError::EntryPoint(naga_stage)
})?;

// collect sizes indices, immutable buffers, and work group memory sizes
let ep_info = &module_info.get_entry_point(ep_index);
let mut wg_memory_sizes = Vec::new();
let mut sized_bindings = Vec::new();
let mut immutable_buffer_mask = 0;
for (var_handle, var) in module.global_variables.iter() {
match var.space {
naga::AddressSpace::WorkGroup => {
if !ep_info[var_handle].is_empty() {
let size = module.types[var.ty].inner.size(module.to_ctx());
wg_memory_sizes.push(size);
}
if self.shared.private_caps.supports_preserve_invariance {
options.set_preserve_invariance(true);
}
naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. } => {
let br = match var.binding {
Some(br) => br,
None => continue,
};
let storage_access_store = match var.space {
naga::AddressSpace::Storage { access } => {
access.contains(naga::StorageAccess::STORE)

let library = self
.shared
.device
.lock()
.new_library_with_source(source.as_ref(), &options)
.map_err(|err| {
log::warn!("Naga generated shader:\n{source}");
crate::PipelineError::Linkage(stage_bit, format!("Metal: {err}"))
})?;

let ep_index = module
.entry_points
.iter()
.position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point)
.ok_or(crate::PipelineError::EntryPoint(naga_stage))?;
let ep = &module.entry_points[ep_index];
let translated_ep_name = info.entry_point_names[0]
.as_ref()
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;

let wg_size = MTLSize {
width: ep.workgroup_size[0] as _,
height: ep.workgroup_size[1] as _,
depth: ep.workgroup_size[2] as _,
};

let function = library
.get_function(translated_ep_name, None)
.map_err(|e| {
log::error!("get_function: {e:?}");
crate::PipelineError::EntryPoint(naga_stage)
})?;

// collect sizes indices, immutable buffers, and work group memory sizes
let ep_info = &module_info.get_entry_point(ep_index);
let mut wg_memory_sizes = Vec::new();
let mut sized_bindings = Vec::new();
let mut immutable_buffer_mask = 0;
for (var_handle, var) in module.global_variables.iter() {
match var.space {
naga::AddressSpace::WorkGroup => {
if !ep_info[var_handle].is_empty() {
let size = module.types[var.ty].inner.size(module.to_ctx());
wg_memory_sizes.push(size);
}
}
_ => false,
};
naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. } => {
let br = match var.binding {
Some(br) => br,
None => continue,
};
let storage_access_store = match var.space {
naga::AddressSpace::Storage { access } => {
access.contains(naga::StorageAccess::STORE)
}
_ => false,
};

// check for an immutable buffer
if !ep_info[var_handle].is_empty() && !storage_access_store {
let slot = ep_resources.resources[&br].buffer.unwrap();
immutable_buffer_mask |= 1 << slot;
}
// check for an immutable buffer
if !ep_info[var_handle].is_empty() && !storage_access_store {
let slot = ep_resources.resources[&br].buffer.unwrap();
immutable_buffer_mask |= 1 << slot;
}

let mut dynamic_array_container_ty = var.ty;
if let naga::TypeInner::Struct { ref members, .. } = module.types[var.ty].inner
{
dynamic_array_container_ty = members.last().unwrap().ty;
}
if let naga::TypeInner::Array {
size: naga::ArraySize::Dynamic,
..
} = module.types[dynamic_array_container_ty].inner
{
sized_bindings.push(br);
let mut dynamic_array_container_ty = var.ty;
if let naga::TypeInner::Struct { ref members, .. } =
module.types[var.ty].inner
{
dynamic_array_container_ty = members.last().unwrap().ty;
}
if let naga::TypeInner::Array {
size: naga::ArraySize::Dynamic,
..
} = module.types[dynamic_array_container_ty].inner
{
sized_bindings.push(br);
}
}
_ => {}
}
}
_ => {}

Ok(CompiledShader {
library,
function,
wg_size,
wg_memory_sizes,
sized_bindings,
immutable_buffer_mask,
})
}
ShaderModuleSource::Passthrough(ref shader) => Ok(CompiledShader {
library: shader.library.clone(),
function: shader.function.clone(),
wg_size: MTLSize {
width: shader.num_workgroups.0 as u64,
height: shader.num_workgroups.1 as u64,
depth: shader.num_workgroups.2 as u64,
},
wg_memory_sizes: vec![],
sized_bindings: vec![],
immutable_buffer_mask: 0,
}),
}

Ok(CompiledShader {
library,
function,
wg_size,
wg_memory_sizes,
sized_bindings,
immutable_buffer_mask,
})
}

fn set_buffers_mutability(
Expand Down
Loading