Skip to content

Commit 573ce44

Browse files
committed
reduction returns value instead of saving directly to storage
1 parent eb44262 commit 573ce44

File tree

3 files changed

+31
-22
lines changed

3 files changed

+31
-22
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ namespace workgroup2
2222
template<class Config, class BinOp, class device_capabilities=void>
2323
struct reduction
2424
{
25-
template<class DataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<DataAccessor> && ArithmeticSharedMemoryAccessor<ScratchAccessor>)
26-
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
25+
using scalar_t = typename BinOp::type_t;
26+
27+
template<class ReadOnlyDataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<ReadOnlyDataAccessor> && ArithmeticSharedMemoryAccessor<ScratchAccessor>)
28+
static scalar_t __call(NBL_REF_ARG(ReadOnlyDataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
2729
{
2830
impl::reduce<Config,BinOp,Config::LevelCount,device_capabilities> fn;
29-
fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
31+
return fn.template __call<ReadOnlyDataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
3032
}
3133
};
3234

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct reduce<Config, BinOp, 1, device_capabilities>
3636
// NOTE: this single-level reduction does not use scratch shared memory; is the scratchAccessor parameter still needed here?
3737

3838
template<class DataAccessor, class ScratchAccessor>
39-
void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
39+
scalar_t __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
4040
{
4141
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
4242
using params_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
@@ -45,7 +45,8 @@ struct reduce<Config, BinOp, 1, device_capabilities>
4545
vector_t value;
4646
dataAccessor.template get<vector_t>(workgroup::SubgroupContiguousIndex(), value);
4747
value = reduction(value);
48-
dataAccessor.template set<vector_t>(workgroup::SubgroupContiguousIndex(), value);
48+
return value[0];
49+
// dataAccessor.template set<vector_t>(workgroup::SubgroupContiguousIndex(), value);
4950
}
5051
};
5152

@@ -87,7 +88,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
8788
using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>;
8889

8990
template<class DataAccessor, class ScratchAccessor>
90-
void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
91+
scalar_t __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
9192
{
9293
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
9394
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
@@ -128,13 +129,16 @@ struct reduce<Config, BinOp, 2, device_capabilities>
128129
scratchAccessor.workgroupExecutionAndMemoryBarrier();
129130

130131
// fetch the final reduction result (read back from scratch index 0 below) and return it to the caller
131-
[unroll]
132-
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
133-
{
134-
scalar_t reduce_val;
135-
scratchAccessor.template get<scalar_t>(0,reduce_val);
136-
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, hlsl::promote<vector_lv0_t>(reduce_val));
137-
}
132+
// [unroll]
133+
// for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
134+
// {
135+
// scalar_t reduce_val;
136+
// scratchAccessor.template get<scalar_t>(0,reduce_val);
137+
// dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, hlsl::promote<vector_lv0_t>(reduce_val));
138+
// }
139+
scalar_t reduce_val;
140+
scratchAccessor.template get<scalar_t>(0,reduce_val);
141+
return reduce_val;
138142
}
139143
};
140144

@@ -225,7 +229,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
225229
using vector_lv2_t = vector<scalar_t, Config::ItemsPerInvocation_2>;
226230

227231
template<class DataAccessor, class ScratchAccessor>
228-
void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
232+
scalar_t __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
229233
{
230234
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
231235
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
@@ -282,13 +286,16 @@ struct reduce<Config, BinOp, 3, device_capabilities>
282286
scratchAccessor.workgroupExecutionAndMemoryBarrier();
283287

284288
// fetch the final reduction result (read back from scratch index 0 below) and return it to the caller
285-
[unroll]
286-
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
287-
{
288-
scalar_t reduce_val;
289-
scratchAccessor.template get<scalar_t>(0,reduce_val);
290-
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val);
291-
}
289+
// [unroll]
290+
// for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
291+
// {
292+
// scalar_t reduce_val;
293+
// scratchAccessor.template get<scalar_t>(0,reduce_val);
294+
// dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val);
295+
// }
296+
scalar_t reduce_val;
297+
scratchAccessor.template get<scalar_t>(0,reduce_val);
298+
return reduce_val;
292299
}
293300
};
294301

0 commit comments

Comments
 (0)