Skip to content

Commit 27ee35c

Browse files
committed
Add ReferenceType trait for command structures
Use it to deduplicate some `Foo` / `ArcFoo` / `TraceFoo` definitions.
1 parent 65024e6 commit 27ee35c

File tree

12 files changed

+196
-115
lines changed

12 files changed

+196
-115
lines changed

Cargo.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ libloading = "0.8"
142142
libm = { version = "0.2.6", default-features = false }
143143
libtest-mimic = "0.8"
144144
log = "0.4.21"
145+
macro_rules_attribute = "0.2"
145146
nanoserde = "0.2"
146147
nanorand = { version = "0.8", default-features = false, features = ["wyrand"] }
147148
noise = "0.9"

player/src/bin/play.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ fn main() {
77

88
use player::GlobalPlay as _;
99
use wgc::device::trace;
10-
use wgpu_core::identity::IdentityManager;
10+
use wgpu_core::{command::IdReferences, identity::IdentityManager};
1111

1212
use std::{
1313
fs,
@@ -52,7 +52,7 @@ fn main() {
5252

5353
log::info!("Loading trace '{trace:?}'");
5454
let file = fs::File::open(trace).unwrap();
55-
let mut actions: Vec<trace::Action> = ron::de::from_reader(file).unwrap();
55+
let mut actions: Vec<trace::Action<IdReferences>> = ron::de::from_reader(file).unwrap();
5656
actions.reverse(); // allows us to pop from the top
5757
log::info!("Found {} actions", actions.len());
5858

player/src/lib.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
extern crate wgpu_core as wgc;
77
extern crate wgpu_types as wgt;
88

9-
use wgc::{command::Command, device::trace, identity::IdentityManager};
9+
use wgc::{
10+
command::{Command, IdReferences},
11+
device::trace,
12+
identity::IdentityManager,
13+
};
1014

1115
use std::{borrow::Cow, fs, path::Path};
1216

@@ -21,7 +25,7 @@ pub trait GlobalPlay {
2125
&self,
2226
device: wgc::id::DeviceId,
2327
queue: wgc::id::QueueId,
24-
action: trace::Action,
28+
action: trace::Action<IdReferences>,
2529
dir: &Path,
2630
command_encoder_id_manager: &mut IdentityManager<wgc::id::markers::CommandEncoder>,
2731
command_buffer_id_manager: &mut IdentityManager<wgc::id::markers::CommandBuffer>,
@@ -143,7 +147,7 @@ impl GlobalPlay for wgc::global::Global {
143147
}
144148
};
145149
wgc::ray_tracing::BlasBuildEntry {
146-
blas_id: x.blas_id,
150+
blas_id: x.blas,
147151
geometries,
148152
}
149153
});
@@ -153,14 +157,14 @@ impl GlobalPlay for wgc::global::Global {
153157
instance
154158
.as_ref()
155159
.map(|instance| wgc::ray_tracing::TlasInstance {
156-
blas_id: instance.blas_id,
160+
blas_id: instance.blas,
157161
transform: &instance.transform,
158162
custom_data: instance.custom_data,
159163
mask: instance.mask,
160164
})
161165
});
162166
wgc::ray_tracing::TlasPackage {
163-
tlas_id: x.tlas_id,
167+
tlas_id: x.tlas,
164168
instances: Box::new(instances),
165169
lowest_unmodified: x.lowest_unmodified,
166170
}
@@ -188,7 +192,7 @@ impl GlobalPlay for wgc::global::Global {
188192
&self,
189193
device: wgc::id::DeviceId,
190194
queue: wgc::id::QueueId,
191-
action: trace::Action,
195+
action: trace::Action<IdReferences>,
192196
dir: &Path,
193197
command_encoder_id_manager: &mut IdentityManager<wgc::id::markers::CommandEncoder>,
194198
command_buffer_id_manager: &mut IdentityManager<wgc::id::markers::CommandBuffer>,

player/tests/player/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use std::{
2020
path::{Path, PathBuf},
2121
slice,
2222
};
23-
use wgc::identity::IdentityManager;
23+
use wgc::{command::IdReferences, identity::IdentityManager};
2424

2525
#[derive(serde::Deserialize)]
2626
struct RawId {
@@ -57,7 +57,7 @@ struct Expectation {
5757
struct Test<'a> {
5858
features: wgt::Features,
5959
expectations: Vec<Expectation>,
60-
actions: Vec<wgc::device::trace::Action<'a>>,
60+
actions: Vec<wgc::device::trace::Action<'a, IdReferences>>,
6161
}
6262

6363
fn map_callback(status: Result<(), wgc::resource::BufferAccessError>) {

wgpu-core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ document-features.workspace = true
178178
hashbrown.workspace = true
179179
indexmap.workspace = true
180180
log.workspace = true
181+
macro_rules_attribute.workspace = true
181182
once_cell = { workspace = true, features = ["std"] }
182183
parking_lot.workspace = true
183184
profiling = { workspace = true, default-features = false }

wgpu-core/src/command/encoder_command.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,77 @@
11
use core::convert::Infallible;
22

33
use alloc::{string::String, sync::Arc, vec::Vec};
4+
#[cfg(feature = "serde")]
5+
use macro_rules_attribute::attribute_alias;
46

57
use crate::{
68
id,
79
resource::{Buffer, QuerySet, Texture},
810
};
911

12+
pub trait ReferenceType {
13+
type Buffer: Clone + core::fmt::Debug;
14+
type Texture: Clone + core::fmt::Debug;
15+
type TextureView: Clone + core::fmt::Debug;
16+
type QuerySet: Clone + core::fmt::Debug;
17+
type BindGroup: Clone + core::fmt::Debug;
18+
type RenderPipeline: Clone + core::fmt::Debug;
19+
type RenderBundle: Clone + core::fmt::Debug;
20+
type ComputePipeline: Clone + core::fmt::Debug;
21+
type Blas: Clone + core::fmt::Debug;
22+
type Tlas: Clone + core::fmt::Debug;
23+
}
24+
25+
#[derive(Clone, Debug)]
26+
pub struct IdReferences;
27+
28+
#[derive(Clone, Debug)]
29+
pub struct ArcReferences;
30+
31+
impl ReferenceType for IdReferences {
32+
type Buffer = id::BufferId;
33+
type Texture = id::TextureId;
34+
type TextureView = id::TextureViewId;
35+
type QuerySet = id::QuerySetId;
36+
type BindGroup = id::BindGroupId;
37+
type RenderPipeline = id::RenderPipelineId;
38+
type RenderBundle = id::RenderBundleId;
39+
type ComputePipeline = id::ComputePipelineId;
40+
type Blas = id::BlasId;
41+
type Tlas = id::TlasId;
42+
}
43+
44+
impl ReferenceType for ArcReferences {
45+
type Buffer = Arc<Buffer>;
46+
type Texture = Arc<Texture>;
47+
type TextureView = Arc<crate::resource::TextureView>;
48+
type QuerySet = Arc<QuerySet>;
49+
type BindGroup = Arc<crate::binding_model::BindGroup>;
50+
type RenderPipeline = Arc<crate::pipeline::RenderPipeline>;
51+
type RenderBundle = Arc<crate::command::RenderBundle>;
52+
type ComputePipeline = Arc<crate::pipeline::ComputePipeline>;
53+
type Blas = Arc<crate::resource::Blas>;
54+
type Tlas = Arc<crate::resource::Tlas>;
55+
}
56+
57+
#[cfg(feature = "serde")]
58+
attribute_alias! {
59+
#[apply(serde_object_reference_struct)] =
60+
#[derive(serde::Serialize, serde::Deserialize)]
61+
#[serde(bound =
62+
"R::Buffer: serde::Serialize + for<'d> serde::Deserialize<'d>,\
63+
R::Texture: serde::Serialize + for<'d> serde::Deserialize<'d>,\
64+
R::TextureView: serde::Serialize + for<'d> serde::Deserialize<'d>,\
65+
R::QuerySet: serde::Serialize + for<'d> serde::Deserialize<'d>,\
66+
R::BindGroup: serde::Serialize + for<'d> serde::Deserialize<'d>,\
67+
R::RenderPipeline: serde::Serialize + for<'d> serde::Deserialize<'d>,\
68+
R::RenderBundle: serde::Serialize + for<'d> serde::Deserialize<'d>,\
69+
R::ComputePipeline: serde::Serialize + for<'d> serde::Deserialize<'d>,\
70+
R::Blas: serde::Serialize + for<'d> serde::Deserialize<'d>,\
71+
R::Tlas: serde::Serialize + for<'d> serde::Deserialize<'d>"
72+
)];
73+
}
74+
1075
#[derive(Clone, Debug)]
1176
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
1277
pub enum Command {

wgpu-core/src/command/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ use core::mem::{self, ManuallyDrop};
3434
use core::ops;
3535

3636
pub(crate) use self::clear::clear_texture;
37+
#[cfg(feature = "serde")]
38+
pub(crate) use self::encoder_command::serde_object_reference_struct;
3739
pub use self::{
3840
bundle::*,
3941
clear::ClearError,
4042
compute::*,
4143
compute_command::{ArcComputeCommand, ComputeCommand},
4244
draw::*,
43-
encoder_command::{ArcCommand, Command},
45+
encoder_command::{ArcCommand, ArcReferences, Command, IdReferences, ReferenceType},
4446
query::*,
4547
render::*,
4648
render_command::{ArcRenderCommand, RenderCommand},

wgpu-core/src/command/ray_tracing.rs

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,29 @@ use core::{
88
use wgt::{math::align_to, BufferUsages, BufferUses, Features};
99

1010
use crate::{
11-
command::CommandBufferMutable,
11+
command::encoder::EncodingState,
12+
ray_tracing::{AsAction, AsBuild, BlasTriangleGeometryInfo, TlasBuild, ValidateAsActionsError},
13+
resource::InvalidResourceError,
14+
track::Tracker,
15+
};
16+
use crate::{command::EncoderStateError, device::resource::CommandIndices};
17+
use crate::{
18+
command::{ArcCommand, ArcReferences, CommandBufferMutable},
1219
device::queue::TempResource,
1320
global::Global,
1421
id::CommandEncoderId,
1522
init_tracker::MemoryInitKind,
1623
ray_tracing::{
17-
BlasBuildEntry, BlasGeometries, BuildAccelerationStructureError, TlasPackage,
18-
TraceBlasBuildEntry, TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance,
19-
TraceTlasPackage,
24+
ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry, ArcTlasInstance,
25+
ArcTlasPackage, BlasBuildEntry, BlasGeometries, BuildAccelerationStructureError,
26+
OwnedBlasBuildEntry, OwnedTlasPackage, TlasPackage, TraceBlasBuildEntry,
27+
TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
2028
},
2129
resource::{Blas, BlasCompactState, Buffer, Labeled, StagingBuffer, Tlas},
2230
scratch::ScratchBuffer,
2331
snatch::SnatchGuard,
2432
track::PendingTransition,
2533
};
26-
use crate::{command::EncoderStateError, device::resource::CommandIndices};
27-
use crate::{
28-
command::{encoder::EncodingState, ArcCommand},
29-
ray_tracing::{
30-
ArcBlasBuildEntry, ArcBlasGeometries, ArcBlasTriangleGeometry, ArcTlasInstance,
31-
ArcTlasPackage, AsAction, AsBuild, BlasTriangleGeometryInfo, TlasBuild,
32-
ValidateAsActionsError,
33-
},
34-
resource::InvalidResourceError,
35-
track::Tracker,
36-
};
3734
use crate::{lock::RwLockWriteGuard, resource::RawResourceAccess};
3835

3936
use crate::id::{BlasId, TlasId};
@@ -146,7 +143,7 @@ impl Global {
146143
}
147144
};
148145
TraceBlasBuildEntry {
149-
blas_id: blas_entry.blas_id,
146+
blas: blas_entry.blas_id,
150147
geometries,
151148
}
152149
})
@@ -158,15 +155,15 @@ impl Global {
158155
.instances
159156
.map(|instance| {
160157
instance.map(|instance| TraceTlasInstance {
161-
blas_id: instance.blas_id,
158+
blas: instance.blas_id,
162159
transform: *instance.transform,
163160
custom_data: instance.custom_data,
164161
mask: instance.mask,
165162
})
166163
})
167164
.collect();
168165
TraceTlasPackage {
169-
tlas_id: package.tlas_id,
166+
tlas: package.tlas_id,
170167
instances,
171168
lowest_unmodified: package.lowest_unmodified,
172169
}
@@ -214,7 +211,7 @@ impl Global {
214211
}
215212
};
216213
Ok(ArcBlasBuildEntry {
217-
blas: self.resolve_blas_id(blas_entry.blas_id)?,
214+
blas: self.resolve_blas_id(blas_entry.blas)?,
218215
geometries,
219216
})
220217
})
@@ -231,7 +228,7 @@ impl Global {
231228
.as_ref()
232229
.map(|instance| {
233230
Ok(ArcTlasInstance {
234-
blas: self.resolve_blas_id(instance.blas_id)?,
231+
blas: self.resolve_blas_id(instance.blas)?,
235232
transform: instance.transform,
236233
custom_data: instance.custom_data,
237234
mask: instance.mask,
@@ -241,7 +238,7 @@ impl Global {
241238
})
242239
.collect::<Result<_, BuildAccelerationStructureError>>()?;
243240
Ok(ArcTlasPackage {
244-
tlas: self.resolve_tlas_id(tlas_package.tlas_id)?,
241+
tlas: self.resolve_tlas_id(tlas_package.tlas)?,
245242
instances,
246243
lowest_unmodified: tlas_package.lowest_unmodified,
247244
})
@@ -255,8 +252,8 @@ impl Global {
255252

256253
pub(crate) fn build_acceleration_structures(
257254
state: &mut EncodingState,
258-
blas: Vec<ArcBlasBuildEntry>,
259-
tlas: Vec<ArcTlasPackage>,
255+
blas: Vec<OwnedBlasBuildEntry<ArcReferences>>,
256+
tlas: Vec<OwnedTlasPackage<ArcReferences>>,
260257
) -> Result<(), BuildAccelerationStructureError> {
261258
state
262259
.device
@@ -281,7 +278,7 @@ pub(crate) fn build_acceleration_structures(
281278
&mut scratch_buffer_blas_size,
282279
&mut blas_storage,
283280
)?;
284-
let mut tlas_lock_store = Vec::<(Option<ArcTlasPackage>, Arc<Tlas>)>::new();
281+
let mut tlas_lock_store = Vec::<(Option<OwnedTlasPackage<ArcReferences>>, Arc<Tlas>)>::new();
285282

286283
for package in tlas.into_iter() {
287284
let tlas = package.tlas.clone();
@@ -614,7 +611,7 @@ impl CommandBufferMutable {
614611

615612
///iterates over the blas iterator, and it's geometry, pushing the buffers into a storage vector (and also some validation).
616613
fn iter_blas(
617-
blas_iter: impl Iterator<Item = ArcBlasBuildEntry>,
614+
blas_iter: impl Iterator<Item = OwnedBlasBuildEntry<ArcReferences>>,
618615
tracker: &mut Tracker,
619616
build_command: &mut AsBuild,
620617
buf_storage: &mut Vec<TriangleBufferStore>,

0 commit comments

Comments
 (0)