@@ -128,14 +128,6 @@ struct reduce<Config, BinOp, 2, device_capabilities>
         }
         scratchAccessor.workgroupExecutionAndMemoryBarrier();
 
-        // set as last element in scan (reduction)
-        // [unroll]
-        // for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
-        // {
-        //     scalar_t reduce_val;
-        //     scratchAccessor.template get<scalar_t>(0,reduce_val);
-        //     dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, hlsl::promote<vector_lv0_t>(reduce_val));
-        // }
         scalar_t reduce_val;
         scratchAccessor.template get<scalar_t>(0,reduce_val);
         return reduce_val;
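After the barrier, the workgroup-wide reduction sits in scratch slot 0 and every invocation just reads and returns it; the commented-out loop that would have broadcast the value back through dataAccessor is deleted, so any broadcast is left to the caller. A minimal host-side C++ sketch of that calling pattern (uint32_t and addition stand in for scalar_t and BinOp; none of these names are Nabla's actual API):

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Emulated "scratch slot 0": the reduction every invocation reads after the barrier.
static uint32_t reduceWorkgroup(const std::vector<uint32_t>& data)
{
    return std::accumulate(data.begin(), data.end(), 0u);
}

int main()
{
    std::vector<uint32_t> data(64);
    std::iota(data.begin(), data.end(), 1u);

    const uint32_t reduce_val = reduceWorkgroup(data);

    // The removed commented-out loop wrote reduce_val back over every element;
    // a caller that still wants that broadcast can do it at the call site.
    std::vector<uint32_t> broadcast(data.size(), reduce_val);
    return broadcast.back() == reduce_val ? 0 : 1;
}
```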
@@ -187,7 +179,9 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
             scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]);
         lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
         lv1_val = inclusiveScan1(lv1_val);
-        scratchAccessor.template set<scalar_t>(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
+        [unroll]
+        for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
+            scratchAccessor.template set<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
     }
     scratchAccessor.workgroupExecutionAndMemoryBarrier();
 
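The level-1 pass now stores every component of lv1_val back to scratch at `i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex`, the same stride it was read with via prevIndex, instead of only the last component at invocationIndex. A small C++ sketch of that flat strided layout (the constant value and helper name are illustrative only):

```cpp
#include <cstdint>
#include <cstdio>

constexpr uint32_t SubgroupsPerVirtualWorkgroup = 16;

// Flat scratch offset for (component, subgroup), mirroring the expression the
// diff uses for both the read (with prevIndex) and the new write.
constexpr uint32_t scratchOffset(uint32_t component, uint32_t subgroup)
{
    return component * SubgroupsPerVirtualWorkgroup + subgroup;
}

int main()
{
    // For a fixed component, consecutive subgroups land in consecutive slots;
    // writing all components (not just the last one) keeps the write layout
    // identical to the one the scan later reads from.
    for (uint32_t i = 0; i < 2; i++)
        for (uint32_t s = 0; s < 4; s++)
            std::printf("component %u, subgroup %u -> slot %u\n", i, s, scratchOffset(i, s));
    return 0;
}
```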
@@ -199,14 +193,16 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
             dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
 
             const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
+            const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
             scalar_t left;
-            scratchAccessor.template get<scalar_t>(virtualSubgroupID,left);
+            scratchAccessor.template get<scalar_t>(bankedIndex,left);
             if (Exclusive)
             {
                 scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
                 [unroll]
-                for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
-                    value[Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(value[Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0)));
+                for (uint32_t i = Config::ItemsPerInvocation_0-1; i > 0; i--)
+                    value[i] = binop(left, value[i-1]);
+                value[0] = binop(left, left_last_elem);
             }
             else
             {
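In the downsweep, `left` is now read through `Config::sharedMemCoalescedIndex`, which presumably maps virtualSubgroupID onto the strided layout the previous hunk writes with, and the exclusive shift becomes an explicit descending loop plus a separate `value[0]` assignment, dropping the old `value[Config::ItemsPerInvocation_0-i-2]` access that wraps below zero on the final iteration even though the mix discards it. A host-side C++ check of the rewritten shift, assuming addition for BinOp, uint32_t for scalar_t, and a single subgroup (`left` models the prefix read from scratch, `left_last_elem` the subgroupShuffleUp of the previous invocation's last element):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    constexpr uint32_t Invocations = 4, ItemsPerInvocation = 4;
    std::vector<uint32_t> data(Invocations * ItemsPerInvocation);
    for (uint32_t i = 0; i < data.size(); i++)
        data[i] = i + 1;

    const uint32_t left = 100; // exclusive prefix contributed by preceding subgroups

    // Reference: exclusive scan of the flat data, offset by `left`.
    std::vector<uint32_t> reference(data.size());
    uint32_t sum = left;
    for (uint32_t i = 0; i < data.size(); i++) { reference[i] = sum; sum += data[i]; }

    // Each invocation holds the subgroup-inclusive scan of its own items.
    std::vector<uint32_t> inclusive(data.size());
    uint32_t running = 0;
    for (uint32_t i = 0; i < data.size(); i++) { running += data[i]; inclusive[i] = running; }

    for (uint32_t inv = 0; inv < Invocations; inv++)
    {
        uint32_t value[ItemsPerInvocation];
        for (uint32_t i = 0; i < ItemsPerInvocation; i++)
            value[i] = inclusive[inv * ItemsPerInvocation + i];
        // subgroupShuffleUp of the last element; identity for invocation 0.
        const uint32_t left_last_elem = (inv == 0) ? 0u : inclusive[inv * ItemsPerInvocation - 1];

        // The rewritten shift from the diff: shift components right by one,
        // fold in `left`, and seed component 0 from the previous invocation.
        for (uint32_t i = ItemsPerInvocation - 1; i > 0; i--)
            value[i] = left + value[i - 1];
        value[0] = left + left_last_elem;

        for (uint32_t i = 0; i < ItemsPerInvocation; i++)
            assert(value[i] == reference[inv * ItemsPerInvocation + i]);
    }
    return 0;
}
```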
@@ -285,14 +281,6 @@ struct reduce<Config, BinOp, 3, device_capabilities>
         }
         scratchAccessor.workgroupExecutionAndMemoryBarrier();
 
-        // set as last element in scan (reduction)
-        // [unroll]
-        // for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
-        // {
-        //     scalar_t reduce_val;
-        //     scratchAccessor.template get<scalar_t>(0,reduce_val);
-        //     dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val);
-        // }
         scalar_t reduce_val;
         scratchAccessor.template get<scalar_t>(0,reduce_val);
         return reduce_val;