@@ -36,7 +36,7 @@ struct reduce<Config, BinOp, 1, device_capabilities>
36
36
// doesn't use scratch smem, need as param?
37
37
38
38
template<class DataAccessor, class ScratchAccessor>
39
- void __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
39
+ scalar_t __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
40
40
{
41
41
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
42
42
using params_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
@@ -45,7 +45,8 @@ struct reduce<Config, BinOp, 1, device_capabilities>
45
45
vector_t value;
46
46
dataAccessor.template get<vector_t>(workgroup::SubgroupContiguousIndex (), value);
47
47
value = reduction (value);
48
- dataAccessor.template set<vector_t>(workgroup::SubgroupContiguousIndex (), value);
48
+ return value[0 ];
49
+ // dataAccessor.template set<vector_t>(workgroup::SubgroupContiguousIndex(), value);
49
50
}
50
51
};
51
52
@@ -87,7 +88,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
87
88
using vector_lv1_t = vector <scalar_t, Config::ItemsPerInvocation_1>;
88
89
89
90
template<class DataAccessor, class ScratchAccessor>
90
- void __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
91
+ scalar_t __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
91
92
{
92
93
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
93
94
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
@@ -128,13 +129,16 @@ struct reduce<Config, BinOp, 2, device_capabilities>
128
129
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
129
130
130
131
// set as last element in scan (reduction)
131
- [unroll]
132
- for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
133
- {
134
- scalar_t reduce_val;
135
- scratchAccessor.template get<scalar_t>(0 ,reduce_val);
136
- dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, hlsl::promote<vector_lv0_t>(reduce_val));
137
- }
132
+ // [unroll]
133
+ // for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
134
+ // {
135
+ // scalar_t reduce_val;
136
+ // scratchAccessor.template get<scalar_t>(0,reduce_val);
137
+ // dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, hlsl::promote<vector_lv0_t>(reduce_val));
138
+ // }
139
+ scalar_t reduce_val;
140
+ scratchAccessor.template get<scalar_t>(0 ,reduce_val);
141
+ return reduce_val;
138
142
}
139
143
};
140
144
@@ -225,7 +229,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
225
229
using vector_lv2_t = vector <scalar_t, Config::ItemsPerInvocation_2>;
226
230
227
231
template<class DataAccessor, class ScratchAccessor>
228
- void __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
232
+ scalar_t __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
229
233
{
230
234
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
231
235
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
@@ -282,13 +286,16 @@ struct reduce<Config, BinOp, 3, device_capabilities>
282
286
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
283
287
284
288
// set as last element in scan (reduction)
285
- [unroll]
286
- for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
287
- {
288
- scalar_t reduce_val;
289
- scratchAccessor.template get<scalar_t>(0 ,reduce_val);
290
- dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val);
291
- }
289
+ // [unroll]
290
+ // for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
291
+ // {
292
+ // scalar_t reduce_val;
293
+ // scratchAccessor.template get<scalar_t>(0,reduce_val);
294
+ // dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val);
295
+ // }
296
+ scalar_t reduce_val;
297
+ scratchAccessor.template get<scalar_t>(0 ,reduce_val);
298
+ return reduce_val;
292
299
}
293
300
};
294
301
0 commit comments