Skip to content

Commit 59d75a7

Browse files
committed
Merge pull request #103613 from stuartcarnie/fix_101696
Metal: Use uniform set index passed by `RenderingDevice`
2 parents e6cb2e8 + 2b8cb36 commit 59d75a7

File tree

2 files changed

+43
-41
lines changed

2 files changed

+43
-41
lines changed

drivers/metal/metal_objects.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -787,18 +787,18 @@ struct BoundUniformSet {
787787

788788
class API_AVAILABLE(macos(11.0), ios(14.0), tvos(14.0)) MDUniformSet {
789789
private:
790-
void bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::RenderState &p_state);
791-
void bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::RenderState &p_state);
792-
void bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state);
793-
void bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state);
790+
void bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index);
791+
void bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index);
792+
void bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index);
793+
void bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index);
794794

795795
public:
796796
uint32_t index;
797797
LocalVector<RDD::BoundUniform> uniforms;
798798
HashMap<MDShader *, BoundUniformSet> bound_uniforms;
799799

800-
void bind_uniforms(MDShader *p_shader, MDCommandBuffer::RenderState &p_state);
801-
void bind_uniforms(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state);
800+
void bind_uniforms(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index);
801+
void bind_uniforms(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index);
802802

803803
BoundUniformSet &bound_uniform_set(MDShader *p_shader, id<MTLDevice> p_device, ResourceUsageMap &p_resource_usage);
804804
};

drivers/metal/metal_objects.mm

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -213,36 +213,38 @@
213213
DEV_ASSERT(type == MDCommandBufferStateType::Render);
214214

215215
MDUniformSet *set = (MDUniformSet *)(p_uniform_set.id);
216-
if (render.uniform_sets.size() <= set->index) {
216+
if (render.uniform_sets.size() <= p_set_index) {
217217
uint32_t s = render.uniform_sets.size();
218-
render.uniform_sets.resize(set->index + 1);
218+
render.uniform_sets.resize(p_set_index + 1);
219219
// Set intermediate values to null.
220-
std::fill(&render.uniform_sets[s], &render.uniform_sets[set->index] + 1, nullptr);
220+
std::fill(&render.uniform_sets[s], &render.uniform_sets[p_set_index] + 1, nullptr);
221221
}
222222

223-
if (render.uniform_sets[set->index] != set) {
223+
if (render.uniform_sets[p_set_index] != set) {
224224
render.dirty.set_flag(RenderState::DIRTY_UNIFORMS);
225-
render.uniform_set_mask |= 1ULL << set->index;
226-
render.uniform_sets[set->index] = set;
225+
render.uniform_set_mask |= 1ULL << p_set_index;
226+
render.uniform_sets[p_set_index] = set;
227227
}
228228
}
229229

230230
void MDCommandBuffer::render_bind_uniform_sets(VectorView<RDD::UniformSetID> p_uniform_sets, RDD::ShaderID p_shader, uint32_t p_first_set_index, uint32_t p_set_count) {
231231
DEV_ASSERT(type == MDCommandBufferStateType::Render);
232232

233-
for (size_t i = 0u; i < p_set_count; ++i) {
233+
for (size_t i = 0; i < p_set_count; ++i) {
234234
MDUniformSet *set = (MDUniformSet *)(p_uniform_sets[i].id);
235-
if (render.uniform_sets.size() <= set->index) {
235+
236+
uint32_t index = p_first_set_index + i;
237+
if (render.uniform_sets.size() <= index) {
236238
uint32_t s = render.uniform_sets.size();
237-
render.uniform_sets.resize(set->index + 1);
239+
render.uniform_sets.resize(index + 1);
238240
// Set intermediate values to null.
239-
std::fill(&render.uniform_sets[s], &render.uniform_sets[set->index] + 1, nullptr);
241+
std::fill(&render.uniform_sets[s], &render.uniform_sets[index] + 1, nullptr);
240242
}
241243

242-
if (render.uniform_sets[set->index] != set) {
244+
if (render.uniform_sets[index] != set) {
243245
render.dirty.set_flag(RenderState::DIRTY_UNIFORMS);
244-
render.uniform_set_mask |= 1ULL << set->index;
245-
render.uniform_sets[set->index] = set;
246+
render.uniform_set_mask |= 1ULL << index;
247+
render.uniform_sets[index] = set;
246248
}
247249
}
248250
}
@@ -474,14 +476,14 @@
474476

475477
while (set_uniforms != 0) {
476478
// Find the index of the next set bit.
477-
int index = __builtin_ctzll(set_uniforms);
479+
uint32_t index = (uint32_t)__builtin_ctzll(set_uniforms);
478480
// Clear the set bit.
479481
set_uniforms &= (set_uniforms - 1);
480482
MDUniformSet *set = render.uniform_sets[index];
481-
if (set == nullptr || set->index >= (uint32_t)shader->sets.size()) {
483+
if (set == nullptr || index >= (uint32_t)shader->sets.size()) {
482484
continue;
483485
}
484-
set->bind_uniforms(shader, render);
486+
set->bind_uniforms(shader, render, index);
485487
}
486488
}
487489

@@ -955,7 +957,7 @@
955957

956958
MDShader *shader = (MDShader *)(p_shader.id);
957959
MDUniformSet *set = (MDUniformSet *)(p_uniform_set.id);
958-
set->bind_uniforms(shader, compute);
960+
set->bind_uniforms(shader, compute, p_set_index);
959961
}
960962

961963
void MDCommandBuffer::compute_bind_uniform_sets(VectorView<RDD::UniformSetID> p_uniform_sets, RDD::ShaderID p_shader, uint32_t p_first_set_index, uint32_t p_set_count) {
@@ -966,7 +968,7 @@
966968
// TODO(sgc): Bind multiple buffers using [encoder setBuffers:offsets:withRange:]
967969
for (size_t i = 0u; i < p_set_count; ++i) {
968970
MDUniformSet *set = (MDUniformSet *)(p_uniform_sets[i].id);
969-
set->bind_uniforms(shader, compute);
971+
set->bind_uniforms(shader, compute, p_first_set_index + i);
970972
}
971973
}
972974

@@ -1052,11 +1054,11 @@
10521054
}
10531055
}
10541056

1055-
void MDUniformSet::bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::RenderState &p_state) {
1057+
void MDUniformSet::bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index) {
10561058
DEV_ASSERT(p_shader->uses_argument_buffers);
10571059
DEV_ASSERT(p_state.encoder != nil);
10581060

1059-
UniformSet const &set_info = p_shader->sets[index];
1061+
UniformSet const &set_info = p_shader->sets[p_set_index];
10601062

10611063
id<MTLRenderCommandEncoder> __unsafe_unretained enc = p_state.encoder;
10621064
id<MTLDevice> __unsafe_unretained device = enc.device;
@@ -1067,25 +1069,25 @@
10671069
{
10681070
uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_VERTEX);
10691071
if (offset) {
1070-
[enc setVertexBuffer:bus.buffer offset:*offset atIndex:index];
1072+
[enc setVertexBuffer:bus.buffer offset:*offset atIndex:p_set_index];
10711073
}
10721074
}
10731075
// Set the buffer for the fragment stage.
10741076
{
10751077
uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_FRAGMENT);
10761078
if (offset) {
1077-
[enc setFragmentBuffer:bus.buffer offset:*offset atIndex:index];
1079+
[enc setFragmentBuffer:bus.buffer offset:*offset atIndex:p_set_index];
10781080
}
10791081
}
10801082
}
10811083

1082-
void MDUniformSet::bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::RenderState &p_state) {
1084+
void MDUniformSet::bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index) {
10831085
DEV_ASSERT(!p_shader->uses_argument_buffers);
10841086
DEV_ASSERT(p_state.encoder != nil);
10851087

10861088
id<MTLRenderCommandEncoder> __unsafe_unretained enc = p_state.encoder;
10871089

1088-
UniformSet const &set = p_shader->sets[index];
1090+
UniformSet const &set = p_shader->sets[p_set_index];
10891091

10901092
for (uint32_t i = 0; i < MIN(uniforms.size(), set.uniforms.size()); i++) {
10911093
RDD::BoundUniform const &uniform = uniforms[i];
@@ -1256,19 +1258,19 @@
12561258
}
12571259
}
12581260

1259-
void MDUniformSet::bind_uniforms(MDShader *p_shader, MDCommandBuffer::RenderState &p_state) {
1261+
void MDUniformSet::bind_uniforms(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index) {
12601262
if (p_shader->uses_argument_buffers) {
1261-
bind_uniforms_argument_buffers(p_shader, p_state);
1263+
bind_uniforms_argument_buffers(p_shader, p_state, p_set_index);
12621264
} else {
1263-
bind_uniforms_direct(p_shader, p_state);
1265+
bind_uniforms_direct(p_shader, p_state, p_set_index);
12641266
}
12651267
}
12661268

1267-
void MDUniformSet::bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state) {
1269+
void MDUniformSet::bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index) {
12681270
DEV_ASSERT(p_shader->uses_argument_buffers);
12691271
DEV_ASSERT(p_state.encoder != nil);
12701272

1271-
UniformSet const &set_info = p_shader->sets[index];
1273+
UniformSet const &set_info = p_shader->sets[p_set_index];
12721274

12731275
id<MTLComputeCommandEncoder> enc = p_state.encoder;
12741276
id<MTLDevice> device = enc.device;
@@ -1277,17 +1279,17 @@
12771279

12781280
uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_COMPUTE);
12791281
if (offset) {
1280-
[enc setBuffer:bus.buffer offset:*offset atIndex:index];
1282+
[enc setBuffer:bus.buffer offset:*offset atIndex:p_set_index];
12811283
}
12821284
}
12831285

1284-
void MDUniformSet::bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state) {
1286+
void MDUniformSet::bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index) {
12851287
DEV_ASSERT(!p_shader->uses_argument_buffers);
12861288
DEV_ASSERT(p_state.encoder != nil);
12871289

12881290
id<MTLComputeCommandEncoder> __unsafe_unretained enc = p_state.encoder;
12891291

1290-
UniformSet const &set = p_shader->sets[index];
1292+
UniformSet const &set = p_shader->sets[p_set_index];
12911293

12921294
for (uint32_t i = 0; i < uniforms.size(); i++) {
12931295
RDD::BoundUniform const &uniform = uniforms[i];
@@ -1407,11 +1409,11 @@
14071409
}
14081410
}
14091411

1410-
void MDUniformSet::bind_uniforms(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state) {
1412+
void MDUniformSet::bind_uniforms(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index) {
14111413
if (p_shader->uses_argument_buffers) {
1412-
bind_uniforms_argument_buffers(p_shader, p_state);
1414+
bind_uniforms_argument_buffers(p_shader, p_state, p_set_index);
14131415
} else {
1414-
bind_uniforms_direct(p_shader, p_state);
1416+
bind_uniforms_direct(p_shader, p_state, p_set_index);
14151417
}
14161418
}
14171419

0 commit comments

Comments
 (0)