Skip to content

Commit e0d82ed

Browse files
Reduce compute shading ceremony, remove branches in compute shaders
1 parent 6446e92 commit e0d82ed

File tree

8 files changed

+344
-124
lines changed

8 files changed

+344
-124
lines changed

src/renderer/app_state.rs

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
use crate::{
22
image::grammar::Image,
33
renderer::{
4+
compute_effect::ComputeEffect,
45
draw_uniform::DrawUniform,
56
feature_uniform::{FeatureUniform, TransformAction},
67
gpu_state::GpuResourceAllocator,
78
mouse_state::MouseState,
89
shader::{Shader, TextureResource},
910
shape::{compute_distance, Circle, EditorState},
1011
shape_uniform::{CircleData, ShapeUniform, MAX_CIRCLES},
12+
Texture,
1113
},
1214
};
1315
use anyhow::Result;
@@ -39,9 +41,9 @@ pub struct AppState<'a> {
3941
pub shape_uniform: ShapeUniform,
4042
pub circle_storage_buffer: wgpu::Buffer,
4143

42-
pub color_correct_pipeline: wgpu::ComputePipeline,
43-
pub color_correct_bind_group: wgpu::BindGroup,
44-
pub output_texture: crate::renderer::Texture,
44+
pub gamma_effect: ComputeEffect,
45+
pub grayscale_effect: ComputeEffect,
46+
pub invert_effect: ComputeEffect,
4547
}
4648

4749
impl<'a> AppState<'a> {
@@ -72,43 +74,36 @@ impl<'a> AppState<'a> {
7274
let circle_storage_buffer =
7375
gpu_allocator.create_storage_buffer("circle_storage", &empty_circles)?;
7476

75-
let color_correct_pipeline = gpu_allocator.create_compute_pipeline(
76-
"color_correct_compute",
77-
include_str!("color_correct_compute.wgsl"),
78-
"cs_main",
79-
);
80-
81-
let output_texture =
82-
gpu_allocator.create_storage_texture("processed_texture", size.width, size.height);
83-
84-
let color_correct_bind_group = {
85-
let bind_group_layout = color_correct_pipeline.get_bind_group_layout(0);
86-
gpu_allocator
87-
.device
88-
.create_bind_group(&wgpu::BindGroupDescriptor {
89-
layout: &bind_group_layout,
90-
entries: &[
91-
wgpu::BindGroupEntry {
92-
binding: 0,
93-
resource: wgpu::BindingResource::TextureView(
94-
&image_texture_resource.resource.view,
95-
),
96-
},
97-
wgpu::BindGroupEntry {
98-
binding: 1,
99-
resource: wgpu::BindingResource::TextureView(&output_texture.view),
100-
},
101-
wgpu::BindGroupEntry {
102-
binding: 2,
103-
resource: feature_uniform_resource.resource.as_entire_binding(),
104-
},
105-
],
106-
label: Some("color_correct_bind_group"),
107-
})
108-
};
77+
let texture_a = gpu_allocator.create_storage_texture("texture_a", size.width, size.height);
78+
79+
// individual compute effects
80+
let gamma_effect = ComputeEffect::builder("gamma")
81+
.with_shader(include_str!("gamma_correct_compute.wgsl"))
82+
.with_uniform(feature_uniform.gamma())
83+
.build(
84+
&gpu_allocator.device,
85+
&image_texture_resource.resource.view,
86+
&texture_a.view,
87+
)?;
88+
89+
let grayscale_effect = ComputeEffect::builder("grayscale")
90+
.with_shader(include_str!("grayscale_compute.wgsl"))
91+
.build(
92+
&gpu_allocator.device,
93+
&image_texture_resource.resource.view,
94+
&texture_a.view,
95+
)?;
96+
97+
let invert_effect = ComputeEffect::builder("invert")
98+
.with_shader(include_str!("invert_compute.wgsl"))
99+
.build(
100+
&gpu_allocator.device,
101+
&image_texture_resource.resource.view,
102+
&texture_a.view,
103+
)?;
109104

110105
let processed_texture_resource = gpu_allocator
111-
.create_texture_resource_from_existing("processed_texture_ref", &output_texture);
106+
.create_texture_resource_from_existing("processed_texture_ref", &texture_a);
112107

113108
let shape_texture_for_image = gpu_allocator.create_texture_resource_from_existing(
114109
"shape_texture_ref",
@@ -147,9 +142,9 @@ impl<'a> AppState<'a> {
147142
shape_render_texture,
148143
shape_uniform,
149144
circle_storage_buffer,
150-
color_correct_pipeline,
151-
color_correct_bind_group,
152-
output_texture,
145+
gamma_effect,
146+
grayscale_effect,
147+
invert_effect,
153148
})
154149
}
155150

@@ -428,6 +423,10 @@ impl<'a> AppState<'a> {
428423
}
429424

430425
pub(crate) fn update(&mut self) {
426+
// Update gamma effect uniform
427+
self.gamma_effect
428+
.update_uniform(&self.gpu_allocator.queue, self.feature_uniform.gamma());
429+
431430
// Update image shader uniforms
432431
let uniform_resources = &self.image_shader.uniform_resources;
433432
self.gpu_allocator
@@ -472,19 +471,23 @@ impl<'a> AppState<'a> {
472471
pub(crate) fn render(&self) -> Result<(), wgpu::SurfaceError> {
473472
let (output, view, mut encoder) = self.gpu_allocator.begin_frame()?;
474473

475-
// compute pass 1: apply color corrections
474+
// Compute shader pass
476475
{
477-
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
478-
label: Some("color_correct_compute_pass"),
479-
timestamp_writes: None,
480-
});
481-
482-
compute_pass.set_pipeline(&self.color_correct_pipeline);
483-
compute_pass.set_bind_group(0, &self.color_correct_bind_group, &[]);
484-
476+
// todo, these color effects work independently from each other
477+
// how do we pass intermediate textures?
485478
let workgroup_count_x = (self.size.width + 15) / 16;
486479
let workgroup_count_y = (self.size.height + 15) / 16;
487-
compute_pass.dispatch_workgroups(workgroup_count_x, workgroup_count_y, 1);
480+
481+
if self.feature_uniform.invert() {
482+
self.invert_effect
483+
.dispatch(&mut encoder, workgroup_count_x, workgroup_count_y);
484+
} else if self.feature_uniform.grayscale() {
485+
self.grayscale_effect
486+
.dispatch(&mut encoder, workgroup_count_x, workgroup_count_y);
487+
} else {
488+
self.gamma_effect
489+
.dispatch(&mut encoder, workgroup_count_x, workgroup_count_y);
490+
}
488491
}
489492

490493
// First pass: Render shapes to shape texture

src/renderer/compute_effect.rs

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
use anyhow::Result;
2+
use wgpu::util::DeviceExt;
3+
4+
// a compute shader effect that can be applied to textures
5+
#[derive(Debug)]
6+
pub struct ComputeEffect {
7+
pub pipeline: wgpu::ComputePipeline,
8+
pub bind_group: wgpu::BindGroup,
9+
uniform_buffer: Option<wgpu::Buffer>,
10+
}
11+
12+
impl ComputeEffect {
13+
pub fn builder<'a>(label: &'a str) -> ComputeEffectBuilder<'a> {
14+
ComputeEffectBuilder {
15+
label,
16+
shader_source: None,
17+
entry_point: "main",
18+
uniform_data: None,
19+
}
20+
}
21+
22+
pub fn update_uniform<T: bytemuck::Pod>(&self, queue: &wgpu::Queue, data: T) {
23+
if let Some(buffer) = &self.uniform_buffer {
24+
queue.write_buffer(buffer, 0, bytemuck::cast_slice(&[data]));
25+
}
26+
}
27+
28+
pub fn dispatch(
29+
&self,
30+
encoder: &mut wgpu::CommandEncoder,
31+
workgroup_count_x: u32,
32+
workgroup_count_y: u32,
33+
) {
34+
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
35+
label: None,
36+
timestamp_writes: None,
37+
});
38+
39+
compute_pass.set_pipeline(&self.pipeline);
40+
compute_pass.set_bind_group(0, &self.bind_group, &[]);
41+
compute_pass.dispatch_workgroups(workgroup_count_x, workgroup_count_y, 1);
42+
}
43+
}
44+
45+
pub struct ComputeEffectBuilder<'a> {
46+
label: &'a str,
47+
shader_source: Option<&'a str>,
48+
entry_point: &'a str,
49+
uniform_data: Option<Vec<u8>>,
50+
}
51+
52+
impl<'a> ComputeEffectBuilder<'a> {
53+
pub fn with_shader(mut self, source: &'a str) -> Self {
54+
self.shader_source = Some(source);
55+
self
56+
}
57+
58+
pub fn with_entry_point(mut self, entry_point: &'a str) -> Self {
59+
self.entry_point = entry_point;
60+
self
61+
}
62+
63+
pub fn with_uniform<T: bytemuck::Pod>(mut self, data: T) -> Self {
64+
self.uniform_data = Some(bytemuck::cast_slice(&[data]).to_vec());
65+
self
66+
}
67+
68+
pub fn build(
69+
self,
70+
device: &wgpu::Device,
71+
input_texture: &wgpu::TextureView,
72+
output_texture: &wgpu::TextureView,
73+
) -> Result<ComputeEffect> {
74+
let shader_source = self
75+
.shader_source
76+
.ok_or_else(|| anyhow::anyhow!("Shader source is required"))?;
77+
78+
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
79+
label: Some(self.label),
80+
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
81+
});
82+
83+
let uniform_buffer = self.uniform_data.as_ref().map(|data| {
84+
device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
85+
label: Some(&format!("{}_uniform", self.label)),
86+
contents: data,
87+
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
88+
})
89+
});
90+
91+
let mut bind_group_layout_entries = vec![
92+
wgpu::BindGroupLayoutEntry {
93+
binding: 0,
94+
visibility: wgpu::ShaderStages::COMPUTE,
95+
ty: wgpu::BindingType::Texture {
96+
multisampled: false,
97+
view_dimension: wgpu::TextureViewDimension::D2,
98+
sample_type: wgpu::TextureSampleType::Float { filterable: false },
99+
},
100+
count: None,
101+
},
102+
// output storage texture
103+
wgpu::BindGroupLayoutEntry {
104+
binding: 1,
105+
visibility: wgpu::ShaderStages::COMPUTE,
106+
ty: wgpu::BindingType::StorageTexture {
107+
access: wgpu::StorageTextureAccess::WriteOnly,
108+
format: wgpu::TextureFormat::Rgba8Unorm,
109+
view_dimension: wgpu::TextureViewDimension::D2,
110+
},
111+
count: None,
112+
},
113+
];
114+
115+
if uniform_buffer.is_some() {
116+
bind_group_layout_entries.push(wgpu::BindGroupLayoutEntry {
117+
binding: 2,
118+
visibility: wgpu::ShaderStages::COMPUTE,
119+
ty: wgpu::BindingType::Buffer {
120+
ty: wgpu::BufferBindingType::Uniform,
121+
has_dynamic_offset: false,
122+
min_binding_size: None,
123+
},
124+
count: None,
125+
});
126+
}
127+
128+
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
129+
entries: &bind_group_layout_entries,
130+
label: Some(&format!("{}_bind_group_layout", self.label)),
131+
});
132+
133+
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
134+
label: Some(&format!("{}_pipeline_layout", self.label)),
135+
bind_group_layouts: &[&bind_group_layout],
136+
push_constant_ranges: &[],
137+
});
138+
139+
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
140+
label: Some(self.label),
141+
layout: Some(&pipeline_layout),
142+
module: &shader,
143+
entry_point: self.entry_point,
144+
compilation_options: Default::default(),
145+
cache: None,
146+
});
147+
148+
let mut bind_group_entries = vec![
149+
wgpu::BindGroupEntry {
150+
binding: 0,
151+
resource: wgpu::BindingResource::TextureView(input_texture),
152+
},
153+
wgpu::BindGroupEntry {
154+
binding: 1,
155+
resource: wgpu::BindingResource::TextureView(output_texture),
156+
},
157+
];
158+
159+
if let Some(buffer) = &uniform_buffer {
160+
bind_group_entries.push(wgpu::BindGroupEntry {
161+
binding: 2,
162+
resource: buffer.as_entire_binding(),
163+
});
164+
}
165+
166+
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
167+
layout: &bind_group_layout,
168+
entries: &bind_group_entries,
169+
label: Some(&format!("{}_bind_group", self.label)),
170+
});
171+
172+
Ok(ComputeEffect {
173+
pipeline,
174+
bind_group,
175+
uniform_buffer,
176+
})
177+
}
178+
}

src/renderer/feature_uniform.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,9 @@ impl FeatureUniform {
174174
}
175175
}
176176
}
177+
178+
impl FeatureUniform {
179+
pub fn gamma(&self) -> u32 {
180+
self.gamma
181+
}
182+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// a compute shader for gamma correction
2+
3+
@group(0)
4+
@binding(0)
5+
var input_texture: texture_2d<f32>;
6+
7+
@group(0)
8+
@binding(1)
9+
var output_texture: texture_storage_2d<rgba8unorm, write>;
10+
11+
@group(0)
12+
@binding(2)
13+
var<uniform> gamma: u32;
14+
15+
@compute
16+
@workgroup_size(16, 16)
17+
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
18+
let coords = vec2<i32>(global_id.xy);
19+
let dimensions = textureDimensions(input_texture);
20+
21+
if (coords.x >= i32(dimensions.x) || coords.y >= i32(dimensions.y)) {
22+
return;
23+
}
24+
25+
var color = textureLoad(input_texture, coords, 0);
26+
27+
// If gamma is 0, just pass through (no correction)
28+
if (gamma != 0u) {
29+
let gamma_value = f32(gamma) / 100000.0;
30+
let inv_gamma = 1.0 / gamma_value;
31+
32+
color = vec4<f32>(
33+
pow(color.r, inv_gamma),
34+
pow(color.g, inv_gamma),
35+
pow(color.b, inv_gamma),
36+
color.a
37+
);
38+
}
39+
40+
textureStore(output_texture, coords, color);
41+
42+
}
43+

0 commit comments

Comments
 (0)