Skip to content

Commit eb44262

Browse files
committed
moved indexing functionality to config struct
1 parent 9c59677 commit eb44262

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ struct ArithmeticConfiguration
7575
static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!");
7676

7777
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementCount = conditional_value<LevelCount==1,uint16_t,0,conditional_value<LevelCount==3,uint16_t,SubgroupSize*ItemsPerInvocation_2,0>::value + SubgroupSize*ItemsPerInvocation_1>::value;
78+
79+
// Computes a flattened "virtual" subgroup index from a subgroup ID and an offset.
//
// id     - subgroup index; the call sites in shared_scan.hlsl pass glsl::gl_SubgroupID().
// offset - multiplier for the per-workgroup subgroup count; the call sites pass their `idx` variable.
//
// Returns offset * (WorkgroupSize >> SubgroupSizeLog2) + id.
// NOTE(review): the shift presumably yields the number of subgroups in one workgroup
// (assumes SubgroupSizeLog2 == log2(SubgroupSize)) - confirm against the struct's definitions.
static uint32_t virtualSubgroupID(const uint32_t id, const uint32_t offset)
80+
{
81+
return offset * (WorkgroupSize >> SubgroupSizeLog2) + id;
82+
}
83+
84+
// Remaps a linear index into an interleaved layout; callers use the result as the
// index they hand to scratchAccessor.
//
// id                 - linear index to remap (a virtual subgroup ID, invocation index or subgroup ID at the call sites).
// itemsPerInvocation - number of items per invocation on the consuming level
//                      (call sites pass ItemsPerInvocation_1 or ItemsPerInvocation_2).
//
// Returns (id & (itemsPerInvocation-1)) * SubgroupsPerVirtualWorkgroup + id/itemsPerInvocation,
// i.e. the low bits of `id` select a row of stride SubgroupsPerVirtualWorkgroup and the
// quotient selects the position inside that row.
// The `&` mask only equals `id % itemsPerInvocation` when itemsPerInvocation is a power of two.
// NOTE(review): one call site in scan<..., 3> previously used Config::SubgroupSize as the row
// stride instead of SubgroupsPerVirtualWorkgroup; routing it through this helper changes the
// computed index unless the two constants are equal - confirm this is intended.
static uint32_t sharedMemCoalescedIndex(const uint32_t id, const uint32_t itemsPerInvocation)
85+
{
86+
return (id & (itemsPerInvocation-1)) * SubgroupsPerVirtualWorkgroup + (id/itemsPerInvocation);
87+
}
7888
};
7989

8090
template<class T>

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ struct reduce<Config, BinOp, 2, device_capabilities>
105105
scan_local = reduction0(scan_local);
106106
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
107107
{
108-
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
109-
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
108+
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
109+
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
110110
scratchAccessor.template set<scalar_t>(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
111111
}
112112
}
@@ -165,8 +165,8 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
165165
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
166166
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
167167
{
168-
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
169-
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
168+
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
169+
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
170170
scratchAccessor.template set<scalar_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
171171
}
172172
}
@@ -194,7 +194,7 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
194194
vector_lv0_t value;
195195
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
196196

197-
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
197+
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
198198
scalar_t left;
199199
scratchAccessor.template get<scalar_t>(virtualSubgroupID,left);
200200
if (Exclusive)
@@ -244,8 +244,8 @@ struct reduce<Config, BinOp, 3, device_capabilities>
244244
scan_local = reduction0(scan_local);
245245
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
246246
{
247-
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
248-
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
247+
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
248+
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
249249
scratchAccessor.template set<scalar_t>(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
250250
}
251251
}
@@ -262,7 +262,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
262262
lv1_val = reduction1(lv1_val);
263263
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
264264
{
265-
const uint32_t bankedIndex = (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2);
265+
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(invocationIndex, Config::ItemsPerInvocation_2); // (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2);
266266
scratchAccessor.template set<scalar_t>(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
267267
}
268268
}
@@ -321,8 +321,8 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
321321
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
322322
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
323323
{
324-
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
325-
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
324+
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
325+
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
326326
scratchAccessor.template set<scalar_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
327327
}
328328
}
@@ -340,7 +340,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
340340
lv1_val = inclusiveScan1(lv1_val);
341341
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
342342
{
343-
const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
343+
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
344344
scratchAccessor.template set<scalar_t>(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
345345
}
346346
}
@@ -378,7 +378,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
378378
vector_lv0_t value;
379379
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
380380

381-
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
381+
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx); // idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
382382
const scalar_t left;
383383
scratchAccessor.template get<scalar_t>(virtualSubgroupID, left);
384384
if (Exclusive)

0 commit comments

Comments
 (0)