Skip to content

Commit 49ca655

Browse files
committed
fixes to 2-level scan indexing
1 parent 573ce44 commit 49ca655

File tree

1 file changed

+8
-20
lines changed

1 file changed

+8
-20
lines changed

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 8 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -128,14 +128,6 @@ struct reduce<Config, BinOp, 2, device_capabilities>
128128
}
129129
scratchAccessor.workgroupExecutionAndMemoryBarrier();
130130

131-
// set as last element in scan (reduction)
132-
// [unroll]
133-
// for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
134-
// {
135-
// scalar_t reduce_val;
136-
// scratchAccessor.template get<scalar_t>(0,reduce_val);
137-
// dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, hlsl::promote<vector_lv0_t>(reduce_val));
138-
// }
139131
scalar_t reduce_val;
140132
scratchAccessor.template get<scalar_t>(0,reduce_val);
141133
return reduce_val;
@@ -187,7 +179,9 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
187179
scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]);
188180
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
189181
lv1_val = inclusiveScan1(lv1_val);
190-
scratchAccessor.template set<scalar_t>(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
182+
[unroll]
183+
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
184+
scratchAccessor.template set<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
191185
}
192186
scratchAccessor.workgroupExecutionAndMemoryBarrier();
193187

@@ -199,14 +193,16 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
199193
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
200194

201195
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
196+
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
202197
scalar_t left;
203-
scratchAccessor.template get<scalar_t>(virtualSubgroupID,left);
198+
scratchAccessor.template get<scalar_t>(bankedIndex,left);
204199
if (Exclusive)
205200
{
206201
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
207202
[unroll]
208-
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
209-
value[Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(value[Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0)));
203+
for (uint32_t i = Config::ItemsPerInvocation_0-1; i > 0; i--)
204+
value[i] = binop(left, value[i-1]);
205+
value[0] = binop(left, left_last_elem);
210206
}
211207
else
212208
{
@@ -285,14 +281,6 @@ struct reduce<Config, BinOp, 3, device_capabilities>
285281
}
286282
scratchAccessor.workgroupExecutionAndMemoryBarrier();
287283

288-
// set as last element in scan (reduction)
289-
// [unroll]
290-
// for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
291-
// {
292-
// scalar_t reduce_val;
293-
// scratchAccessor.template get<scalar_t>(0,reduce_val);
294-
// dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val);
295-
// }
296284
scalar_t reduce_val;
297285
scratchAccessor.template get<scalar_t>(0,reduce_val);
298286
return reduce_val;

0 commit comments

Comments
 (0)