Skip to content

Commit 7dfa004

Browse files
committed
shader-rt: cleanup codegen with common sym struct
1 parent 5f873b6 commit 7dfa004

File tree

1 file changed

+97
-65
lines changed

1 file changed

+97
-65
lines changed

node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs

Lines changed: 97 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -20,59 +20,90 @@ impl Parse for PerPixelAdjust {
2020

2121
impl ShaderCodegen for PerPixelAdjust {
2222
fn codegen(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream) -> syn::Result<ShaderTokens> {
23-
let (shader_entry_point, entry_point_name) = self.codegen_shader_entry_point(parsed)?;
24-
let gpu_node = self.codegen_gpu_node(parsed, node_cfg, &entry_point_name)?;
25-
Ok(ShaderTokens { shader_entry_point, gpu_node })
26-
}
27-
}
28-
29-
impl PerPixelAdjust {
30-
fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<(TokenStream, TokenStream)> {
3123
let fn_name = &parsed.fn_name;
32-
let gpu_mod = format_ident!("{}_gpu_entry_point", fn_name);
33-
let spirv_image_ty = quote!(Image2d);
3424

25+
// categorize params and assign image bindings
3526
// bindings for images start at 1
36-
let mut binding_cnt = 0;
37-
let params = parsed
38-
.fields
39-
.iter()
40-
.map(|f| {
41-
let ident = &f.pat_ident;
42-
match &f.ty {
43-
ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")),
44-
ParsedFieldType::Regular(RegularParsedField { gpu_image: false, ty, .. }) => Ok(Param {
45-
ident: Cow::Borrowed(&ident.ident),
46-
ty: Cow::Owned(ty.to_token_stream()),
47-
param_type: ParamType::Uniform,
48-
}),
49-
ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }) => {
50-
binding_cnt += 1;
51-
Ok(Param {
52-
ident: Cow::Owned(format_ident!("image_{}", &ident.ident)),
53-
ty: Cow::Borrowed(&spirv_image_ty),
54-
param_type: ParamType::Image { binding: binding_cnt },
55-
})
27+
let params = {
28+
let mut binding_cnt = 0;
29+
parsed
30+
.fields
31+
.iter()
32+
.map(|f| {
33+
let ident = &f.pat_ident;
34+
match &f.ty {
35+
ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")),
36+
ParsedFieldType::Regular(RegularParsedField { gpu_image: false, ty, .. }) => Ok(Param {
37+
ident: Cow::Borrowed(&ident.ident),
38+
ty: ty.to_token_stream(),
39+
param_type: ParamType::Uniform,
40+
}),
41+
ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }) => {
42+
binding_cnt += 1;
43+
Ok(Param {
44+
ident: Cow::Owned(format_ident!("image_{}", &ident.ident)),
45+
ty: quote!(Image2d),
46+
param_type: ParamType::Image { binding: binding_cnt },
47+
})
48+
}
5649
}
57-
}
58-
})
59-
.collect::<syn::Result<Vec<_>>>()?;
50+
})
51+
.collect::<syn::Result<Vec<_>>>()?
52+
};
53+
54+
let entry_point_mod = format_ident!("{}_gpu_entry_point", fn_name);
55+
let entry_point_name_ident = format_ident!("ENTRY_POINT_NAME");
56+
let entry_point_name = quote!(#entry_point_mod::#entry_point_name_ident);
57+
let gpu_node_mod = format_ident!("{}_gpu", fn_name);
58+
59+
let codegen = PerPixelAdjustCodegen {
60+
parsed,
61+
node_cfg,
62+
params,
63+
entry_point_mod,
64+
entry_point_name_ident,
65+
entry_point_name,
66+
gpu_node_mod,
67+
};
6068

61-
let uniform_members = params
69+
Ok(ShaderTokens {
70+
shader_entry_point: codegen.codegen_shader_entry_point()?,
71+
gpu_node: codegen.codegen_gpu_node()?,
72+
})
73+
}
74+
}
75+
76+
pub struct PerPixelAdjustCodegen<'a> {
77+
parsed: &'a ParsedNodeFn,
78+
node_cfg: &'a TokenStream,
79+
params: Vec<Param<'a>>,
80+
entry_point_mod: Ident,
81+
entry_point_name_ident: Ident,
82+
entry_point_name: TokenStream,
83+
gpu_node_mod: Ident,
84+
}
85+
86+
impl PerPixelAdjustCodegen<'_> {
87+
fn codegen_shader_entry_point(&self) -> syn::Result<TokenStream> {
88+
let fn_name = &self.parsed.fn_name;
89+
let uniform_members = self
90+
.params
6291
.iter()
6392
.filter_map(|Param { ident, ty, param_type }| match param_type {
6493
ParamType::Image { .. } => None,
6594
ParamType::Uniform => Some(quote! {#ident: #ty}),
6695
})
6796
.collect::<Vec<_>>();
68-
let image_params = params
97+
let image_params = self
98+
.params
6999
.iter()
70100
.filter_map(|Param { ident, ty, param_type }| match param_type {
71101
ParamType::Image { binding } => Some(quote! {#[spirv(descriptor_set = 0, binding = #binding)] #ident: &#ty}),
72102
ParamType::Uniform => None,
73103
})
74104
.collect::<Vec<_>>();
75-
let call_args = params
105+
let call_args = self
106+
.params
76107
.iter()
77108
.map(|Param { ident, param_type, .. }| match param_type {
78109
ParamType::Image { .. } => quote!(Color::from_vec4(#ident.fetch_with(texel_coord, lod(0)))),
@@ -81,11 +112,10 @@ impl PerPixelAdjust {
81112
.collect::<Vec<_>>();
82113
let context = quote!(());
83114

84-
let entry_point_name = format_ident!("ENTRY_POINT_NAME");
85-
let entry_point_sym = quote!(#gpu_mod::#entry_point_name);
86-
87-
let shader_entry_point = quote! {
88-
pub mod #gpu_mod {
115+
let entry_point_mod = &self.entry_point_mod;
116+
let entry_point_name = &self.entry_point_name_ident;
117+
Ok(quote! {
118+
pub mod #entry_point_mod {
89119
use super::*;
90120
use graphene_core_shaders::color::Color;
91121
use spirv_std::spirv;
@@ -111,23 +141,19 @@ impl PerPixelAdjust {
111141
*color_out = color.to_vec4();
112142
}
113143
}
114-
};
115-
Ok((shader_entry_point, entry_point_sym))
144+
})
116145
}
117146

118-
fn codegen_gpu_node(&self, parsed: &ParsedNodeFn, node_cfg: &TokenStream, entry_point_name: &TokenStream) -> syn::Result<TokenStream> {
119-
let fn_name = format_ident!("{}_gpu", parsed.fn_name);
120-
let struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal));
121-
let mod_name = fn_name.clone();
122-
123-
let gcore = match &parsed.crate_name {
147+
fn codegen_gpu_node(&self) -> syn::Result<TokenStream> {
148+
let gcore = match &self.parsed.crate_name {
124149
FoundCrate::Itself => format_ident!("crate"),
125150
FoundCrate::Name(name) => format_ident!("{name}"),
126151
};
127-
let raster_gpu: Type = parse_quote!(#gcore::table::Table<#gcore::raster_types::Raster<#gcore::raster_types::GPU>>);
128152

129153
// adapt fields for gpu node
130-
let mut fields = parsed
154+
let raster_gpu: Type = parse_quote!(#gcore::table::Table<#gcore::raster_types::Raster<#gcore::raster_types::GPU>>);
155+
let mut fields = self
156+
.parsed
131157
.fields
132158
.iter()
133159
.map(|f| match &f.ty {
@@ -144,7 +170,7 @@ impl PerPixelAdjust {
144170
})
145171
.collect::<syn::Result<Vec<_>>>()?;
146172

147-
// wgpu_executor field
173+
// insert wgpu_executor field
148174
let wgpu_executor = format_ident!("__wgpu_executor");
149175
fields.push(ParsedField {
150176
pat_ident: PatIdent {
@@ -174,17 +200,19 @@ impl PerPixelAdjust {
174200
unit: None,
175201
});
176202

177-
// exactly one gpu_image field, may be expanded later
203+
// find exactly one gpu_image field, runtime doesn't support more than 1 atm
178204
let gpu_image_field = {
179205
let mut iter = fields.iter().filter(|f| matches!(f.ty, ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. })));
180206
match (iter.next(), iter.next()) {
181207
(Some(v), None) => Ok(v),
182208
(Some(_), Some(more)) => Err(syn::Error::new_spanned(&more.pat_ident, "No more than one parameter must be annotated with `#[gpu_image]`")),
183-
(None, _) => Err(syn::Error::new_spanned(&parsed.fn_name, "At least one parameter must be annotated with `#[gpu_image]`")),
209+
(None, _) => Err(syn::Error::new_spanned(&self.parsed.fn_name, "At least one parameter must be annotated with `#[gpu_image]`")),
184210
}?
185211
};
186212
let gpu_image = &gpu_image_field.pat_ident.ident;
187213

214+
// node function body
215+
let entry_point_name = &self.entry_point_name;
188216
let body = quote! {
189217
{
190218
#wgpu_executor.shader_runtime.run_per_pixel_adjust(&::wgpu_executor::shader_runtime::Shaders {
@@ -194,47 +222,51 @@ impl PerPixelAdjust {
194222
}
195223
};
196224

225+
// call node codegen
197226
let mut parsed_node_fn = ParsedNodeFn {
198-
vis: parsed.vis.clone(),
227+
vis: self.parsed.vis.clone(),
199228
attributes: NodeFnAttributes {
200229
shader_node: Some(ShaderNodeType::GpuNode),
201-
..parsed.attributes.clone()
230+
..self.parsed.attributes.clone()
202231
},
203-
fn_name,
204-
struct_name,
205-
mod_name: mod_name.clone(),
232+
fn_name: self.gpu_node_mod.clone(),
233+
struct_name: format_ident!("{}", self.gpu_node_mod.to_string().to_case(Case::Pascal)),
234+
mod_name: self.gpu_node_mod.clone(),
206235
fn_generics: vec![parse_quote!('a: 'n)],
207236
where_clause: None,
208237
input: Input {
209-
pat_ident: parsed.input.pat_ident.clone(),
238+
pat_ident: self.parsed.input.pat_ident.clone(),
210239
ty: parse_quote!(impl #gcore::context::Ctx),
211240
implementations: Default::default(),
212241
},
213242
output_type: raster_gpu,
214243
is_async: true,
215244
fields,
216245
body,
217-
crate_name: parsed.crate_name.clone(),
246+
crate_name: self.parsed.crate_name.clone(),
218247
description: "".to_string(),
219248
};
220249
parsed_node_fn.replace_impl_trait_in_input();
221-
let gpu_node = crate::codegen::generate_node_code(&parsed_node_fn)?;
250+
let gpu_node_impl = crate::codegen::generate_node_code(&parsed_node_fn)?;
222251

252+
// wrap node in `mod #gpu_node_mod`
253+
let node_cfg = self.node_cfg;
254+
let gpu_node_mod = &self.gpu_node_mod;
223255
Ok(quote! {
224256
#node_cfg
225-
mod #mod_name {
257+
mod #gpu_node_mod {
226258
use super::*;
227259
use wgpu_executor::WgpuExecutor;
228260

229-
#gpu_node
261+
#gpu_node_impl
230262
}
231263
})
232264
}
233265
}
234266

235267
struct Param<'a> {
236268
ident: Cow<'a, Ident>,
237-
ty: Cow<'a, TokenStream>,
269+
ty: TokenStream,
238270
param_type: ParamType,
239271
}
240272

0 commit comments

Comments
 (0)