@@ -121,7 +121,9 @@ struct reduce<Config, BinOp, 2, device_capabilities>
121
121
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_1; i++)
122
122
scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
123
123
lv1_val = reduction1 (lv1_val);
124
- scratchAccessor.template set<scalar_t>(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1 ]);
124
+
125
+ if (glsl::gl_SubgroupInvocationID ()==Config::SubgroupSize-1 )
126
+ scratchAccessor.template set<scalar_t>(0 , lv1_val[Config::ItemsPerInvocation_1-1 ]);
125
127
}
126
128
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
127
129
@@ -130,7 +132,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
130
132
for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
131
133
{
132
134
scalar_t reduce_val;
133
- scratchAccessor.template get<scalar_t>(glsl:: gl_SubgroupInvocationID () ,reduce_val);
135
+ scratchAccessor.template get<scalar_t>(0 ,reduce_val);
134
136
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, hlsl::promote<vector_lv0_t>(reduce_val));
135
137
}
136
138
}
@@ -179,9 +181,9 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
179
181
[unroll]
180
182
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_1; i++)
181
183
scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]);
182
- vector_lv1_t shiftedInput = hlsl::mix (hlsl::promote<vector_lv1_t>( BinOp::identity) , lv1_val, bool (invocationIndex));
183
- shiftedInput = inclusiveScan1 (shiftedInput );
184
- scratchAccessor.template set<scalar_t>(invocationIndex, shiftedInput [Config::ItemsPerInvocation_1-1 ]);
184
+ lv1_val[ 0 ] = hlsl::mix (BinOp::identity, lv1_val[ 0 ] , bool (invocationIndex));
185
+ lv1_val = inclusiveScan1 (lv1_val );
186
+ scratchAccessor.template set<scalar_t>(invocationIndex, lv1_val [Config::ItemsPerInvocation_1-1 ]);
185
187
}
186
188
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
187
189
@@ -284,7 +286,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
284
286
for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
285
287
{
286
288
scalar_t reduce_val;
287
- scratchAccessor.template get<scalar_t>(glsl:: gl_SubgroupInvocationID () ,reduce_val);
289
+ scratchAccessor.template get<scalar_t>(0 ,reduce_val);
288
290
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val);
289
291
}
290
292
}
@@ -353,8 +355,8 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
353
355
[unroll]
354
356
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_2; i++)
355
357
scratchAccessor.template get<scalar_t>(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]);
356
- vector_lv2_t shiftedInput = hlsl::mix (hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val, bool (invocationIndex));
357
- shiftedInput = inclusiveScan2 (shiftedInput );
358
+ lv2_val[ 0 ] = hlsl::mix (hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val[ 0 ] , bool (invocationIndex));
359
+ vector_lv2_t shiftedScan = inclusiveScan2 (lv2_val );
358
360
359
361
// combine with level 1, only last element of each
360
362
[unroll]
@@ -363,7 +365,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
363
365
scalar_t last_val;
364
366
scratchAccessor.template get<scalar_t>((Config::ItemsPerInvocation_1-1 )*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1 -i),last_val);
365
367
scalar_t val = hlsl::mix (hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val, bool (i));
366
- val = binop (last_val, shiftedInput [Config::ItemsPerInvocation_2-1 ]);
368
+ val = binop (last_val, shiftedScan [Config::ItemsPerInvocation_2-1 ]);
367
369
scratchAccessor.template set<scalar_t>((Config::ItemsPerInvocation_1-1 )*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1 -i), last_val);
368
370
}
369
371
}
0 commit comments