Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1c90d19
Initial commit
SupaMaggie70Incorporated Aug 14, 2025
8c3e550
Other initial changes
SupaMaggie70Incorporated Aug 14, 2025
85bbc5a
Updated shader snapshots
SupaMaggie70Incorporated Aug 14, 2025
ccf8467
Added new HLSL limitation
SupaMaggie70Incorporated Aug 17, 2025
e55c02f
Moved error to global variable error
SupaMaggie70Incorporated Aug 17, 2025
f3a31a4
Merge branch 'trunk' into mesh-shading/naga-ir
SupaMaggie70Incorporated Aug 17, 2025
0f6da75
Added docs to per_primitive
SupaMaggie70Incorporated Aug 20, 2025
3017214
Added a little bit more docs here and there in IR
SupaMaggie70Incorporated Aug 20, 2025
19b55b5
Merge branch 'trunk' into mesh-shading/naga-ir
SupaMaggie70Incorporated Aug 20, 2025
198437b
Adding validation to ensure that task shaders have a task payload
SupaMaggie70Incorporated Aug 20, 2025
64000e4
Updated spec to reflect the change to payload variables
SupaMaggie70Incorporated Aug 20, 2025
0575e98
Merge branch 'trunk' into mesh-shading/naga-ir
SupaMaggie70Incorporated Aug 22, 2025
b572ec7
Updated the mesh shading spec because it was goofy
SupaMaggie70Incorporated Aug 24, 2025
34d0411
Merge branch 'trunk' into mesh-shading/naga-ir
SupaMaggie70Incorporated Aug 24, 2025
02664e4
Merge branch 'trunk' into mesh-shading/naga-ir
SupaMaggie70Incorporated Aug 24, 2025
7bec4dd
some doc tweaks
jimblandy Aug 25, 2025
2fcb853
Tried to clarify docs a little
SupaMaggie70Incorporated Aug 25, 2025
3009b5a
Merge branch 'trunk' into mesh-shading/naga-ir
SupaMaggie70Incorporated Aug 25, 2025
8bfe106
Tried to update spec
SupaMaggie70Incorporated Aug 25, 2025
6ccaeec
Removed a warning
SupaMaggie70Incorporated Aug 25, 2025
5b7ba11
Addressed comment about docs mistake
SupaMaggie70Incorporated Aug 25, 2025
29c6972
Merge branch 'trunk' into mesh-shading/naga-ir
SupaMaggie70Incorporated Aug 30, 2025
63fa8b5
Merge branch 'trunk' into mesh-shading/naga-ir
SupaMaggie70Incorporated Sep 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions docs/api-specs/mesh_shading.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,32 +80,36 @@ This shader stage can be selected by marking a function with `@task`. Task shade

The output of this determines how many workgroups of mesh shaders will be dispatched. Once dispatched, global id variables will be local to the task shader workgroup dispatch, and mesh shaders won't know the position of their dispatch among all mesh shader dispatches unless this is passed through the payload. The output may be zero to skip dispatching any mesh shader workgroups for the task shader workgroup.

If task shaders are marked with `@payload(someVar)`, where `someVar` is global variable declared like `var<workgroup> someVar: <type>`, task shaders may write to `someVar`. This payload is passed to the mesh shader workgroup that is invoked. The mesh shader can skip declaring `@payload` to ignore this input.
If task shaders are marked with `@payload(someVar)`, where `someVar` is global variable declared like `var<task_payload> someVar: <type>`, task shaders may use `someVar` as if it is a read-write workgroup storage variable. This payload is passed to the mesh shader workgroup that is invoked. The mesh shader can skip declaring `@payload` to ignore this input.

### Mesh shader
This shader stage can be selected by marking a function with `@mesh`. Mesh shaders must not return anything.

Mesh shaders can be marked with `@payload(someVar)` similar to task shaders. Unlike task shaders, mesh shaders cannot write to this workgroup memory. Declaring `@payload` in a pipeline with no task shader, in a pipeline with a task shader that doesn't declare `@payload`, or in a task shader with an `@payload` that is statically sized and smaller than the mesh shader payload is illegal.
Mesh shaders can be marked with `@payload(someVar)` similar to task shaders. Unlike task shaders, mesh shaders cannot write to this memory. Declaring `@payload` in a pipeline with no task shader, in a pipeline with a task shader that doesn't declare `@payload`, or in a task shader with an `@payload` that is statically sized and smaller than the mesh shader payload is illegal.

Mesh shaders must be marked with `@vertex_output(OutputType, numOutputs)`, where `numOutputs` is the maximum number of vertices to be output by a mesh shader, and `OutputType` is the data associated with vertices, similar to a standard vertex shader output.
Mesh shaders must be marked with `@vertex_output(OutputType, numOutputs)`, where `numOutputs` is the maximum number of vertices to be output by a mesh shader, and `OutputType` is the data associated with vertices, similar to a standard vertex shader output, and must be a struct.

Mesh shaders must also be marked with `@primitive_output(OutputType, numOutputs)`, which is similar to `@vertex_output` except it describes the primitive outputs.

### Mesh shader outputs

Primitive outputs from mesh shaders have some additional builtins they can set. These include `@builtin(cull_primitive)`, which must be a boolean value. If this is set to true, then the primitive is skipped during rendering.
Vertex outputs from mesh shaders function identically to outputs of vertex shaders, and as such must have a field with `@builtin(position)`.

Primitive outputs from mesh shaders have some additional builtins they can set. These include `@builtin(cull_primitive)`, which must be a boolean value. If this is set to true, then the primitive is skipped during rendering. All non-builtin primitive outputs must be decorated with `@per_primitive`.

Mesh shader primitive outputs must also specify exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`. This determines the output topology of the mesh shader, and must match the output topology of the pipeline descriptor the mesh shader is used with. These must be of type `vec3<u32>`, `vec2<u32>`, and `u32` respectively. When setting this, each of the indices must be less than the number of vertices declared in `setMeshOutputs`.

Additionally, the `@location` attributes from the vertex and primitive outputs can't overlap.

Before setting any vertices or indices, or exiting, the mesh shader must call `setMeshOutputs(numVertices: u32, numIndices: u32)`, which declares the number of vertices and indices that will be written to. These must be less than the corresponding maximums set in `@vertex_output` and `@primitive_output`. The mesh shader must then write to exactly these numbers of vertices and primitives.
Before setting any vertices or indices, or exiting, the mesh shader must call `setMeshOutputs(numVertices: u32, numIndices: u32)`, which declares the number of vertices and indices that will be written to. These must be less than the corresponding maximums set in `@vertex_output` and `@primitive_output`. The mesh shader must then write to exactly these numbers of vertices and primitives. A varying member with `@per_primitive` cannot be used in function interfaces except as the primitive output for mesh shaders or as input for fragment shaders.

The mesh shader can write to vertices using the `setVertex(idx: u32, vertex: VertexOutput)` where `VertexOutput` is replaced with the vertex type declared in `@vertex_output`, and `idx` is the index of the vertex to write. Similarly, the mesh shader can write to vertices using `setPrimitive(idx: u32, primitive: PrimitiveOutput)`. These can be written to multiple times, however unsynchronized writes are undefined behavior. The primitives and indices are shared across the entire mesh shader workgroup.

### Fragment shader

Fragment shaders may now be passed the primitive info from a mesh shader the same was as they are passed vertex inputs, for example `fn fs_main(vertex: VertexOutput, primitive: PrimitiveOutput)`. The primitive state is part of the fragment input and must match the output of the mesh shader in the pipeline.
Fragment shaders can access vertex output data as if it is from a vertex shader. They can also access primitive output data, provided the input is decorated with `@per_primitive`. The `@per_primitive` attribute can be applied to a value directly, such as `@per_primitive @location(1) value: vec4<f32>`, to a struct such as `@per_primitive primitive_input: PrimitiveInput` where `PrimitiveInput` is a struct containing fields decorated with `@location` and `@builtin`, or to members of a struct that are themselves decorated with `@location` or `@builtin`.

The primitive state is part of the fragment input and must match the output of the mesh shader in the pipeline. Using `@per_primitive` also requires enabling the mesh shader extension. Additionally, the locations of vertex and primitive input cannot overlap.

### Full example

Expand All @@ -115,9 +119,9 @@ The following is a full example of WGSL shaders that could be used to create a m
enable mesh_shading;

const positions = array(
vec4(0.,-1.,0.,1.),
vec4(-1.,1.,0.,1.),
vec4(1.,1.,0.,1.)
vec4(0.,1.,0.,1.),
vec4(-1.,-1.,0.,1.),
vec4(1.,-1.,0.,1.)
);
const colors = array(
vec4(0.,1.,0.,1.),
Expand All @@ -128,7 +132,7 @@ struct TaskPayload {
colorMask: vec4<f32>,
visible: bool,
}
var<workgroup> taskPayload: TaskPayload;
var<task_payload> taskPayload: TaskPayload;
var<workgroup> workgroupData: f32;
struct VertexOutput {
@builtin(position) position: vec4<f32>,
Expand All @@ -137,14 +141,12 @@ struct VertexOutput {
struct PrimitiveOutput {
@builtin(triangle_indices) index: vec3<u32>,
@builtin(cull_primitive) cull: bool,
@location(1) colorMask: vec4<f32>,
@per_primitive @location(1) colorMask: vec4<f32>,
}
struct PrimitiveInput {
@location(1) colorMask: vec4<f32>,
@per_primitive @location(1) colorMask: vec4<f32>,
}
fn test_function(input: u32) {

}
@task
@payload(taskPayload)
@workgroup_size(1)
Expand All @@ -163,8 +165,6 @@ fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocati
workgroupData = 2.0;
var v: VertexOutput;

test_function(1);

v.position = positions[0];
v.color = colors[0] * taskPayload.colorMask;
setVertex(0, v);
Expand Down
25 changes: 25 additions & 0 deletions naga-cli/src/bin/naga.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ struct Args {
#[argh(option)]
shader_model: Option<ShaderModelArg>,

/// the SPIR-V version to use if targeting SPIR-V
///
/// For example, 1.0, 1.4, etc
#[argh(option)]
spirv_version: Option<SpirvVersionArg>,

/// the shader stage, for example 'frag', 'vert', or 'compute'.
/// if the shader stage is unspecified it will be derived from
/// the file extension.
Expand Down Expand Up @@ -189,6 +195,22 @@ impl FromStr for ShaderModelArg {
}
}

#[derive(Debug, Clone)]
struct SpirvVersionArg(u8, u8);

impl FromStr for SpirvVersionArg {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let dot = s
.find(".")
.ok_or_else(|| "Missing dot separator".to_owned())?;
let major = s[..dot].parse::<u8>().map_err(|e| e.to_string())?;
let minor = s[dot + 1..].parse::<u8>().map_err(|e| e.to_string())?;
Ok(Self(major, minor))
}
}

/// Newtype so we can implement [`FromStr`] for `ShaderSource`.
#[derive(Debug, Clone, Copy)]
struct ShaderStage(naga::ShaderStage);
Expand Down Expand Up @@ -465,6 +487,9 @@ fn run() -> anyhow::Result<()> {
if let Some(ref version) = args.metal_version {
params.msl.lang_version = version.0;
}
if let Some(ref version) = args.spirv_version {
params.spv_out.lang_version = (version.0, version.1);
}
params.keep_coordinate_space = args.keep_coordinate_space;

params.dot.cfg_only = args.dot_cfg_only;
Expand Down
19 changes: 19 additions & 0 deletions naga/src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,25 @@ impl StatementGraph {
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
}
}
S::MeshFunction(crate::MeshFunction::SetMeshOutputs {
vertex_count,
primitive_count,
}) => {
self.dependencies.push((id, vertex_count, "vertex_count"));
self.dependencies
.push((id, primitive_count, "primitive_count"));
"SetMeshOutputs"
}
S::MeshFunction(crate::MeshFunction::SetVertex { index, value }) => {
self.dependencies.push((id, index, "index"));
self.dependencies.push((id, value, "value"));
"SetVertex"
}
S::MeshFunction(crate::MeshFunction::SetPrimitive { index, value }) => {
self.dependencies.push((id, index, "index"));
self.dependencies.push((id, value, "value"));
"SetPrimitive"
}
S::SubgroupBallot { result, predicate } => {
if let Some(predicate) = predicate {
self.dependencies.push((id, predicate, "predicate"));
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/glsl/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ impl<W> Writer<'_, W> {
interpolation,
sampling,
blend_src,
per_primitive: _,
} => {
if interpolation == Some(Interpolation::Linear) {
self.features.request(Features::NOPERSPECTIVE_QUALIFIER);
Expand Down
23 changes: 22 additions & 1 deletion naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ impl crate::AddressSpace {
| crate::AddressSpace::Uniform
| crate::AddressSpace::Storage { .. }
| crate::AddressSpace::Handle
| crate::AddressSpace::PushConstant => false,
| crate::AddressSpace::PushConstant
| crate::AddressSpace::TaskPayload => false,
}
}
}
Expand Down Expand Up @@ -1300,6 +1301,9 @@ impl<'a, W: Write> Writer<'a, W> {
crate::AddressSpace::Storage { .. } => {
self.write_interface_block(handle, global)?;
}
crate::AddressSpace::TaskPayload => {
self.write_interface_block(handle, global)?;
}
// A global variable in the `Function` address space is a
// contradiction in terms.
crate::AddressSpace::Function => unreachable!(),
Expand Down Expand Up @@ -1614,6 +1618,7 @@ impl<'a, W: Write> Writer<'a, W> {
interpolation,
sampling,
blend_src,
per_primitive: _,
} => (location, interpolation, sampling, blend_src),
crate::Binding::BuiltIn(built_in) => {
match built_in {
Expand Down Expand Up @@ -1732,6 +1737,7 @@ impl<'a, W: Write> Writer<'a, W> {
interpolation: None,
sampling: None,
blend_src,
per_primitive: false,
},
stage: self.entry_point.stage,
options: VaryingOptions::from_writer_options(self.options, output),
Expand Down Expand Up @@ -2669,6 +2675,11 @@ impl<'a, W: Write> Writer<'a, W> {
self.write_image_atomic(ctx, image, coordinate, array_index, fun, value)?
}
Statement::RayQuery { .. } => unreachable!(),
Statement::MeshFunction(
crate::MeshFunction::SetMeshOutputs { .. }
| crate::MeshFunction::SetVertex { .. }
| crate::MeshFunction::SetPrimitive { .. },
) => unreachable!(),
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let res_name = Baked(result).to_string();
Expand Down Expand Up @@ -5247,6 +5258,15 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s
Bi::SubgroupId => "gl_SubgroupID",
Bi::SubgroupSize => "gl_SubgroupSize",
Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
// mesh
// TODO: figure out how to map these to glsl things as glsl treats them as arrays
Bi::CullPrimitive
| Bi::PointIndex
| Bi::LineIndices
| Bi::TriangleIndices
| Bi::MeshTaskSize => {
unimplemented!()
}
}
}

Expand All @@ -5262,6 +5282,7 @@ const fn glsl_storage_qualifier(space: crate::AddressSpace) -> Option<&'static s
As::Handle => Some("uniform"),
As::WorkGroup => Some("shared"),
As::PushConstant => Some("uniform"),
As::TaskPayload => unreachable!(),
}
}

Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ impl crate::BuiltIn {
Self::PointSize | Self::ViewIndex | Self::PointCoord | Self::DrawID => {
return Err(Error::Custom(format!("Unsupported builtin {self:?}")))
}
Self::CullPrimitive => "SV_CullPrimitive",
Self::PointIndex | Self::LineIndices | Self::TriangleIndices => unimplemented!(),
Self::MeshTaskSize => unreachable!(),
})
}
}
Expand Down
3 changes: 2 additions & 1 deletion naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ impl crate::ShaderStage {
Self::Vertex => "vs",
Self::Fragment => "ps",
Self::Compute => "cs",
Self::Task | Self::Mesh => unreachable!(),
Self::Task => "ts",
Self::Mesh => "ms",
}
}
}
Expand Down
19 changes: 17 additions & 2 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

self.write_wrapped_functions(module, &ctx)?;

if ep.stage == ShaderStage::Compute {
if ep.stage.compute_like() {
// HLSL is calling workgroup size "num threads"
let num_threads = ep.workgroup_size;
writeln!(
Expand Down Expand Up @@ -967,6 +967,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_type(module, global.ty)?;
""
}
crate::AddressSpace::TaskPayload => unimplemented!(),
crate::AddressSpace::Uniform => {
// constant buffer declarations are expected to be inlined, e.g.
// `cbuffer foo: register(b0) { field1: type1; }`
Expand Down Expand Up @@ -2599,6 +2600,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, ".Abort();")?;
}
},
Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs {
vertex_count,
primitive_count,
}) => {
write!(self.out, "{level}SetMeshOutputCounts(")?;
self.write_expr(module, vertex_count, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, primitive_count, func_ctx)?;
write!(self.out, ");")?;
}
Statement::MeshFunction(
crate::MeshFunction::SetVertex { .. } | crate::MeshFunction::SetPrimitive { .. },
) => unimplemented!(),
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let name = Baked(result).to_string();
Expand Down Expand Up @@ -3076,7 +3090,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::AddressSpace::Function
| crate::AddressSpace::Private
| crate::AddressSpace::WorkGroup
| crate::AddressSpace::PushConstant,
| crate::AddressSpace::PushConstant
| crate::AddressSpace::TaskPayload,
)
| None => true,
Some(crate::AddressSpace::Uniform) => {
Expand Down
5 changes: 5 additions & 0 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ impl Options {
interpolation,
sampling,
blend_src,
per_primitive: _,
} => match mode {
LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)),
LocationMode::FragmentOutput => {
Expand Down Expand Up @@ -651,6 +652,10 @@ impl ResolvedBinding {
Bi::CullDistance | Bi::ViewIndex | Bi::DrawID => {
return Err(Error::UnsupportedBuiltIn(built_in))
}
Bi::CullPrimitive => "primitive_culled",
// TODO: figure out how to make this written as a function call
Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices => unimplemented!(),
Bi::MeshTaskSize => unreachable!(),
};
write!(out, "{name}")?;
}
Expand Down
Loading
Loading