Commit 350c6a3

more util funcs in config, fix some calculations
1 parent 27d84c8 commit 350c6a3

3 files changed (+50 −52 lines)

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 23 additions & 25 deletions
@@ -19,9 +19,9 @@ template<uint16_t WorkgroupSizeLog2, uint16_t SubgroupSizeLog2>
 struct virtual_wg_size_log2
 {
     static_assert(WorkgroupSizeLog2>=SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize");
-    // static_assert(WorkgroupSizeLog2<=SubgroupSizeLog2+4, "WorkgroupSize cannot be larger than SubgroupSize*16");
+    static_assert(WorkgroupSizeLog2<=SubgroupSizeLog2*3+4, "WorkgroupSize cannot be larger than (SubgroupSize^3)*16");
     NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value;
-    NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v<uint32_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>+SubgroupSizeLog2;
+    NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v<uint32_t, SubgroupSizeLog2*levels, WorkgroupSizeLog2>;
     // must have at least enough level 0 outputs to feed a single subgroup
 };

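Worked example of the fixed `value` calculation: with SubgroupSizeLog2=4 and WorkgroupSizeLog2=11, `levels` evaluates to 3 (since 11 > 4*2+2), so the new value is max(4*3, 11) = 12, i.e. a virtual workgroup of 2^12 invocations. The old expression gave max(11-4, 4)+4 = 11, under-sizing the virtual workgroup for exactly the configuration that previously needed an explicit specialization (removed in the next hunk) — presumably among the "fix some calculations" the commit message refers to.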
@@ -33,24 +33,6 @@ struct items_per_invocation
     NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value<VirtualWorkgroup::levels==3, uint16_t,mpl::min_v<uint16_t,ItemsPerInvocationProductLog2,2>, ItemsPerInvocationProductLog2>::value;
     NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v<int16_t,ItemsPerInvocationProductLog2-2,0>;
 };
-
-// explicit specializations for cases that don't fit
-#define SPECIALIZE_VIRTUAL_WG_SIZE_CASE(WGLOG2, SGLOG2, LEVELS, VALUE) template<>\
-struct virtual_wg_size_log2<WGLOG2, SGLOG2>\
-{\
-    NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = LEVELS;\
-    NBL_CONSTEXPR_STATIC_INLINE uint16_t value = VALUE;\
-};\
-
-SPECIALIZE_VIRTUAL_WG_SIZE_CASE(11,4,3,12);
-SPECIALIZE_VIRTUAL_WG_SIZE_CASE(7,7,1,7);
-SPECIALIZE_VIRTUAL_WG_SIZE_CASE(6,6,1,6);
-SPECIALIZE_VIRTUAL_WG_SIZE_CASE(5,5,1,5);
-SPECIALIZE_VIRTUAL_WG_SIZE_CASE(4,4,1,4);
-SPECIALIZE_VIRTUAL_WG_SIZE_CASE(3,3,1,3);
-SPECIALIZE_VIRTUAL_WG_SIZE_CASE(2,2,1,2);
-
-#undef SPECIALIZE_VIRTUAL_WG_SIZE_CASE
 }

 template<uint16_t _WorkgroupSizeLog2, uint16_t _SubgroupSizeLog2, uint16_t _ItemsPerInvocation>
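With the explicit specializations deleted, the generalized formulas have to reproduce them. A minimal host-side C++ mirror of the compile-time math (illustrative only, not part of the commit) checks every removed case:

    #include <algorithm>
    #include <cstdint>

    // mirrors virtual_wg_size_log2::levels and ::value from the hunk above
    constexpr uint16_t levels(uint16_t wgLog2, uint16_t sgLog2)
    {
        return wgLog2 > sgLog2 ? (wgLog2 > sgLog2*2+2 ? 3 : 2) : 1;
    }
    constexpr uint16_t value(uint16_t wgLog2, uint16_t sgLog2)
    {
        return std::max<uint16_t>(sgLog2*levels(wgLog2,sgLog2), wgLog2);
    }

    // every removed SPECIALIZE_VIRTUAL_WG_SIZE_CASE(WGLOG2,SGLOG2,LEVELS,VALUE)
    // now falls out of the general formulas
    static_assert(levels(11,4)==3 && value(11,4)==12, "");
    static_assert(levels(7,7)==1 && value(7,7)==7, "");
    static_assert(levels(6,6)==1 && value(6,6)==6, "");
    static_assert(levels(5,5)==1 && value(5,5)==5, "");
    static_assert(levels(4,4)==1 && value(4,4)==4, "");
    static_assert(levels(3,3)==1 && value(3,3)==3, "");
    static_assert(levels(2,2)==1 && value(2,2)==2, "");

    int main() {}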
@@ -71,16 +53,32 @@ struct ArithmeticConfiguration
     NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
     static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!");

-    NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementCount = conditional_value<LevelCount==1,uint16_t,0,conditional_value<LevelCount==3,uint16_t,SubgroupSize*ItemsPerInvocation_2,0>::value + SubgroupSize*ItemsPerInvocation_1>::value;
+    NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1,uint16_t,
+        0,
+        conditional_value<LevelCount==3,uint16_t,
+            SubgroupSize*ItemsPerInvocation_2,
+            0
+        >::value + SubgroupSize*ItemsPerInvocation_1
+    >::value;
+
+    static bool electLast()
+    {
+        return glsl::gl_SubgroupInvocationID()==SubgroupSize-1;
+    }
+
+    static uint32_t virtualSubgroupID(const uint32_t subgroupID, const uint32_t virtualIdx)
+    {
+        return virtualIdx * (WorkgroupSize >> SubgroupSizeLog2) + subgroupID;
+    }

-    static uint32_t virtualSubgroupID(const uint32_t id, const uint32_t offset)
+    static uint32_t sharedCoalescedIndexNextLevel(const uint32_t subgroupID, const uint32_t itemsPerInvocation)
     {
-        return offset * (WorkgroupSize >> SubgroupSizeLog2) + id;
+        return (subgroupID & (itemsPerInvocation-1)) * SubgroupSize + (subgroupID/itemsPerInvocation);
     }

-    static uint32_t sharedMemCoalescedIndex(const uint32_t id, const uint32_t itemsPerInvocation)
+    static uint32_t sharedCoalescedIndexByComponent(const uint32_t invocationIndex, const uint32_t component)
     {
-        return (id & (itemsPerInvocation-1)) * SubgroupSize + (id/itemsPerInvocation);
+        return component * SubgroupSize + invocationIndex;
     }
 };

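The rename splits the old sharedMemCoalescedIndex into a writer-side mapping (sharedCoalescedIndexNextLevel, keyed by the producing subgroup's ID) and a reader-side one (sharedCoalescedIndexByComponent, keyed by invocation index and vector component). A host-side C++ sketch (illustrative, fixed SubgroupSize, not part of the commit) verifying that subgroup j's partial lands exactly where component j%IPI of invocation j/IPI reads it on the next level:

    #include <cassert>
    #include <cstdint>

    constexpr uint32_t SubgroupSize = 32; // assumed for illustration

    constexpr uint32_t sharedCoalescedIndexNextLevel(uint32_t subgroupID, uint32_t itemsPerInvocation)
    {
        // itemsPerInvocation is a power of two, so `& (n-1)` is `% n`
        return (subgroupID & (itemsPerInvocation-1)) * SubgroupSize + subgroupID/itemsPerInvocation;
    }

    constexpr uint32_t sharedCoalescedIndexByComponent(uint32_t invocationIndex, uint32_t component)
    {
        return component * SubgroupSize + invocationIndex;
    }

    int main()
    {
        constexpr uint32_t IPI = 4; // e.g. ItemsPerInvocation_1
        for (uint32_t j = 0; j < SubgroupSize*IPI; j++)
            assert(sharedCoalescedIndexNextLevel(j, IPI) ==
                   sharedCoalescedIndexByComponent(j/IPI, j%IPI));
    }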
include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 26 additions & 26 deletions
@@ -104,10 +104,10 @@ struct reduce<Config, BinOp, 2, device_capabilities>
     vector_lv0_t scan_local;
     dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local);
     scan_local = reduction0(scan_local);
-    if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
+    if (Config::electLast())
     {
         const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-        const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
+        const uint32_t bankedIndex = Config::sharedCoalescedIndexNextLevel(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
         scratchAccessor.template set<scalar_t>(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
     }
 }
@@ -120,10 +120,10 @@ struct reduce<Config, BinOp, 2, device_capabilities>
     vector_lv1_t lv1_val;
     [unroll]
     for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-        scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
+        scratchAccessor.template get<scalar_t>(Config::sharedCoalescedIndexByComponent(invocationIndex, i),lv1_val[i]);
     lv1_val = reduction1(lv1_val);

-    if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
+    if (Config::electLast())
         scratchAccessor.template set<scalar_t>(0, lv1_val[Config::ItemsPerInvocation_1-1]);
 }
 scratchAccessor.workgroupExecutionAndMemoryBarrier();
@@ -159,10 +159,10 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
     dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
     value = inclusiveScan0(value);
     dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
-    if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
+    if (Config::electLast())
     {
         const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-        const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
+        const uint32_t bankedIndex = Config::sharedCoalescedIndexNextLevel(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
         scratchAccessor.template set<scalar_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
     }
 }
@@ -176,12 +176,12 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
     const uint32_t prevIndex = invocationIndex-1;
     [unroll]
     for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-        scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+prevIndex,lv1_val[i]);
+        scratchAccessor.template get<scalar_t>(Config::sharedCoalescedIndexByComponent(prevIndex, i),lv1_val[i]);
     lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
     lv1_val = inclusiveScan1(lv1_val);
     [unroll]
     for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-        scratchAccessor.template set<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
+        scratchAccessor.template set<scalar_t>(Config::sharedCoalescedIndexByComponent(invocationIndex, i),lv1_val[i]);
 }
 scratchAccessor.workgroupExecutionAndMemoryBarrier();

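The prevIndex read combined with hlsl::mix is the standard shift trick: an exclusive scan equals an inclusive scan of the input shifted right by one, with BinOp::identity injected at slot 0. A host-side C++ check of that identity (illustrative, addition as the binary op):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main()
    {
        const std::vector<uint32_t> in{3,1,4,1,5};

        // "read prevIndex, mix in the identity at index 0" from the hunk above
        std::vector<uint32_t> shifted(in.size());
        for (size_t i = 0; i < in.size(); i++)
            shifted[i] = (i == 0) ? 0u : in[i-1]; // 0 is the identity for +

        uint32_t run = 0;
        for (size_t i = 0; i < in.size(); i++)
        {
            run += shifted[i];             // inclusive scan of the shifted input
            uint32_t exclusive = 0;        // reference exclusive scan of `in`
            for (size_t j = 0; j < i; j++)
                exclusive += in[j];
            assert(run == exclusive);
        }
    }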
@@ -193,7 +193,7 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
     dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);

     const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-    const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
+    const uint32_t bankedIndex = Config::sharedCoalescedIndexNextLevel(virtualSubgroupID, Config::ItemsPerInvocation_1);
     scalar_t left;
     scratchAccessor.template get<scalar_t>(bankedIndex,left);
     if (Exclusive)
@@ -242,10 +242,10 @@ struct reduce<Config, BinOp, 3, device_capabilities>
     vector_lv0_t scan_local;
     dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local);
     scan_local = reduction0(scan_local);
-    if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
+    if (Config::electLast())
     {
         const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-        const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
+        const uint32_t bankedIndex = Config::sharedCoalescedIndexNextLevel(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
         scratchAccessor.template set<scalar_t>(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
     }
 }
@@ -258,11 +258,11 @@ struct reduce<Config, BinOp, 3, device_capabilities>
     vector_lv1_t lv1_val;
     [unroll]
     for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-        scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
+        scratchAccessor.template get<scalar_t>(Config::sharedCoalescedIndexByComponent(invocationIndex, i),lv1_val[i]);
     lv1_val = reduction1(lv1_val);
-    if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
+    if (Config::electLast())
     {
-        const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(invocationIndex, Config::ItemsPerInvocation_2); // (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2);
+        const uint32_t bankedIndex = Config::sharedCoalescedIndexNextLevel(invocationIndex, Config::ItemsPerInvocation_2); // (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2);
         scratchAccessor.template set<scalar_t>(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
     }
 }
@@ -275,7 +275,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
     vector_lv2_t lv2_val;
     [unroll]
     for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
-        scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv2_val[i]);
+        scratchAccessor.template get<scalar_t>(Config::sharedCoalescedIndexByComponent(invocationIndex, i),lv2_val[i]);
     lv2_val = reduction2(lv2_val);
     scratchAccessor.template set<scalar_t>(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]);
 }
@@ -314,10 +314,10 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
     dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
     value = inclusiveScan0(value);
     dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
-    if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
+    if (Config::electLast())
     {
         const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-        const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
+        const uint32_t bankedIndex = Config::sharedCoalescedIndexNextLevel(virtualSubgroupID, Config::ItemsPerInvocation_1);
         scratchAccessor.template set<scalar_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
     }
 }
@@ -332,15 +332,15 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
     const uint32_t prevIndex = invocationIndex-1;
     [unroll]
     for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-        scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+prevIndex,lv1_val[i]);
+        scratchAccessor.template get<scalar_t>(Config::sharedCoalescedIndexByComponent(prevIndex, i),lv1_val[i]);
     lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
     lv1_val = inclusiveScan1(lv1_val);
     [unroll]
     for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-        scratchAccessor.template set<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
-    if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
+        scratchAccessor.template set<scalar_t>(Config::sharedCoalescedIndexByComponent(invocationIndex, i),lv1_val[i]);
+    if (Config::electLast())
     {
-        const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
+        const uint32_t bankedIndex = Config::sharedCoalescedIndexNextLevel(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
         scratchAccessor.template set<scalar_t>(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
     }
 }
@@ -354,12 +354,12 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
     const uint32_t prevIndex = invocationIndex-1;
     [unroll]
     for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
-        scratchAccessor.template get<scalar_t>(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]);
+        scratchAccessor.template get<scalar_t>(lv1_smem_size+Config::sharedCoalescedIndexByComponent(prevIndex, i),lv2_val[i]);
     lv2_val[0] = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val[0], bool(invocationIndex));
     lv2_val = inclusiveScan2(lv2_val);
     [unroll]
     for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
-        scratchAccessor.template set<scalar_t>(lv1_smem_size+i*Config::SubgroupSize+invocationIndex,lv2_val[i]);
+        scratchAccessor.template set<scalar_t>(lv1_smem_size+Config::sharedCoalescedIndexByComponent(invocationIndex, i),lv2_val[i]);
 }
 scratchAccessor.workgroupExecutionAndMemoryBarrier();

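Throughout the 3-level path the level-2 partials live at offset lv1_smem_size inside the same scratch buffer. That two-region layout appears to be what the reworked SharedScratchElementCount in arithmetic_config.hlsl accounts for: zero for a single-level config, SubgroupSize*ItemsPerInvocation_1 elements for level 1, plus another SubgroupSize*ItemsPerInvocation_2 only when LevelCount==3.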
@@ -372,12 +372,12 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
         scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);

     scalar_t lv2_scan;
-    const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
+    const uint32_t bankedIndex = Config::sharedCoalescedIndexNextLevel(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
     scratchAccessor.template set<scalar_t>(lv1_smem_size+bankedIndex, lv2_scan);

     [unroll]
     for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-        scratchAccessor.template set<scalar_t>(i*Config::SubgroupSize+invocationIndex, binop(lv1_val[i],lv2_scan));
+        scratchAccessor.template set<scalar_t>(Config::sharedCoalescedIndexByComponent(invocationIndex, i), binop(lv1_val[i],lv2_scan));
 }

 // combine with level 0
@@ -388,7 +388,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
     dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);

     const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-    const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
+    const uint32_t bankedIndex = Config::sharedCoalescedIndexNextLevel(virtualSubgroupID, Config::ItemsPerInvocation_1);
     scalar_t left;
     scratchAccessor.template get<scalar_t>(bankedIndex,left);
     if (Exclusive)
