Skip to content

Commit 9b85882

Browse files
committed
[dx12] implement num_workgroups
1 parent 6ba30cb commit 9b85882

File tree

8 files changed

+48
-16
lines changed

8 files changed

+48
-16
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

wgpu-core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ thiserror = "1"
3636

3737
[dependencies.naga]
3838
git = "https://github.com/gfx-rs/naga"
39-
rev = "7613798"
39+
rev = "4e181d6"
4040
features = ["wgsl-in"]
4141

4242
[dependencies.wgt]

wgpu-hal/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ core-graphics-types = "0.1"
6565

6666
[dependencies.naga]
6767
git = "https://github.com/gfx-rs/naga"
68-
rev = "7613798"
68+
rev = "4e181d6"
6969

7070
[dev-dependencies.naga]
7171
git = "https://github.com/gfx-rs/naga"
72-
rev = "7613798"
72+
rev = "4e181d6"
7373
features = ["wgsl-in"]
7474

7575
[dev-dependencies]

wgpu-hal/src/dx12/command.rs

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ impl super::CommandEncoder {
4646
}
4747

4848
unsafe fn prepare_draw(&mut self, base_vertex: i32, base_instance: u32) {
49-
let list = self.list.unwrap();
5049
while self.pass.dirty_vertex_buffers != 0 {
50+
let list = self.list.unwrap();
5151
let index = self.pass.dirty_vertex_buffers.trailing_zeros();
5252
self.pass.dirty_vertex_buffers ^= 1 << index;
5353
list.IASetVertexBuffers(
@@ -61,6 +61,7 @@ impl super::CommandEncoder {
6161
super::RootElement::SpecialConstantBuffer {
6262
base_vertex: other_vertex,
6363
base_instance: other_instance,
64+
other: _,
6465
} => base_vertex != other_vertex || base_instance != other_instance,
6566
_ => true,
6667
};
@@ -70,13 +71,33 @@ impl super::CommandEncoder {
7071
super::RootElement::SpecialConstantBuffer {
7172
base_vertex,
7273
base_instance,
74+
other: 0,
7375
};
7476
}
7577
}
7678
self.update_root_elements();
7779
}
7880

79-
fn prepare_dispatch(&mut self) {
81+
fn prepare_dispatch(&mut self, count: [u32; 3]) {
82+
if let Some(root_index) = self.pass.layout.special_constants_root_index {
83+
let needs_update = match self.pass.root_elements[root_index as usize] {
84+
super::RootElement::SpecialConstantBuffer {
85+
base_vertex,
86+
base_instance,
87+
other,
88+
} => [base_vertex as u32, base_instance, other] != count,
89+
_ => true,
90+
};
91+
if needs_update {
92+
self.pass.dirty_root_elements |= 1 << root_index;
93+
self.pass.root_elements[root_index as usize] =
94+
super::RootElement::SpecialConstantBuffer {
95+
base_vertex: count[0] as i32,
96+
base_instance: count[1],
97+
other: count[2],
98+
};
99+
}
100+
}
80101
self.update_root_elements();
81102
}
82103

@@ -95,12 +116,17 @@ impl super::CommandEncoder {
95116
super::RootElement::SpecialConstantBuffer {
96117
base_vertex,
97118
base_instance,
119+
other,
98120
} => match self.pass.kind {
99121
Pk::Render => {
100122
list.set_graphics_root_constant(index, base_vertex as u32, 0);
101123
list.set_graphics_root_constant(index, base_instance, 1);
102124
}
103-
Pk::Compute => (),
125+
Pk::Compute => {
126+
list.set_compute_root_constant(index, base_vertex as u32, 0);
127+
list.set_compute_root_constant(index, base_instance, 1);
128+
list.set_compute_root_constant(index, other, 2);
129+
}
104130
Pk::Transfer => (),
105131
},
106132
super::RootElement::Table(descriptor) => match self.pass.kind {
@@ -141,6 +167,7 @@ impl super::CommandEncoder {
141167
super::RootElement::SpecialConstantBuffer {
142168
base_vertex: 0,
143169
base_instance: 0,
170+
other: 0,
144171
};
145172
}
146173
self.pass.layout = layout.clone();
@@ -934,10 +961,12 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
934961
}
935962

936963
unsafe fn dispatch(&mut self, count: [u32; 3]) {
937-
self.prepare_dispatch();
964+
self.prepare_dispatch(count);
938965
self.list.unwrap().dispatch(count);
939966
}
940967
unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
968+
self.prepare_dispatch([0; 3]);
969+
//TODO: update special constants indirectly
941970
self.list.unwrap().ExecuteIndirect(
942971
self.shared.cmd_signatures.dispatch.as_mut_ptr(),
943972
1,

wgpu-hal/src/dx12/device.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -745,15 +745,15 @@ impl crate::Device<super::Api> for super::Device {
745745
);
746746
let mut parameters = Vec::new();
747747

748-
let (special_constants_root_index, special_constants_binding) = if desc
749-
.flags
750-
.contains(crate::PipelineLayoutFlags::BASE_VERTEX_INSTANCE)
751-
{
748+
let (special_constants_root_index, special_constants_binding) = if desc.flags.intersects(
749+
crate::PipelineLayoutFlags::BASE_VERTEX_INSTANCE
750+
| crate::PipelineLayoutFlags::NUM_WORK_GROUPS,
751+
) {
752752
let parameter_index = parameters.len();
753753
parameters.push(native::RootParameter::constants(
754-
native::ShaderVisibility::VS,
754+
native::ShaderVisibility::All, // really needed for VS and CS only
755755
native_binding(&bind_cbv),
756-
2, // 0 = base vertex, 1 = base instance
756+
3, // 0 = base vertex, 1 = base instance, 2 = other
757757
));
758758
let binding = bind_cbv.clone();
759759
bind_cbv.register += 1;

wgpu-hal/src/dx12/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ enum RootElement {
283283
SpecialConstantBuffer {
284284
base_vertex: i32,
285285
base_instance: u32,
286+
other: u32,
286287
},
287288
/// Descriptor table.
288289
Table(native::GpuDescriptor),

wgpu-hal/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,8 @@ bitflags!(
510510
pub struct PipelineLayoutFlags: u32 {
511511
/// Include support for base vertex/instance drawing.
512512
const BASE_VERTEX_INSTANCE = 1 << 0;
513+
/// Include support for num work groups builtin.
514+
const NUM_WORK_GROUPS = 1 << 1;
513515
}
514516
);
515517

wgpu/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,13 @@ env_logger = "0.8"
7575

7676
[dependencies.naga]
7777
git = "https://github.com/gfx-rs/naga"
78-
rev = "7613798"
78+
rev = "4e181d6"
7979
optional = true
8080

8181
# used to test all the example shaders
8282
[dev-dependencies.naga]
8383
git = "https://github.com/gfx-rs/naga"
84-
rev = "7613798"
84+
rev = "4e181d6"
8585
features = ["wgsl-in"]
8686

8787
[[example]]

0 commit comments

Comments
 (0)