@@ -84,64 +84,60 @@ struct max
84
84
85
85
86
86
// subgroup method emulations on the CPU, to verify the results of the GPU methods
87
- template <typename T>
87
+ template <class CRTP , typename T>
88
88
struct emulatedSubgroupCommon
89
89
{
90
- inline const T* getSubgroupData (uint32_t & subgroupInvocationID, uint32_t & pseudoSubgroupID, const T* workgroupData, const uint32_t localInvocationIndex, uint32_t subgroupSize, uint32_t workgroupSize)
90
+ using type_t = T;
91
+
92
+ inline void operator ()(type_t * outputData, const type_t * workgroupData, uint32_t workgroupSize, uint32_t subgroupSize)
91
93
{
92
- pseudoSubgroupID = localInvocationIndex&(-subgroupSize);
93
- auto subgroupData = workgroupData+pseudoSubgroupID;
94
- subgroupInvocationID = localInvocationIndex-pseudoSubgroupID;
95
- return workgroupData+pseudoSubgroupID;
94
+ for (uint32_t pseudoSubgroupID=0u ; pseudoSubgroupID<workgroupSize; pseudoSubgroupID+=subgroupSize)
95
+ {
96
+ type_t * outSubgroupData = outputData+pseudoSubgroupID;
97
+ const type_t * subgroupData = workgroupData+pseudoSubgroupID;
98
+ CRTP::impl (outSubgroupData,subgroupData,core::min<uint32_t >(subgroupSize,workgroupSize-pseudoSubgroupID));
99
+ }
96
100
}
97
101
};
98
102
template <class OP >
99
- struct emulatedSubgroupReduction : emulatedSubgroupCommon<typename OP::type_t >
103
+ struct emulatedSubgroupReduction : emulatedSubgroupCommon<emulatedSubgroupReduction<OP>, typename OP::type_t >
100
104
{
101
105
using type_t = typename OP::type_t ;
102
106
103
- inline type_t operator ()( const type_t * workgroupData , const uint32_t localInvocationIndex, uint32_t subgroupSize, uint32_t workgroupSize )
107
+ static inline void impl ( type_t * outSubgroupData , const type_t * subgroupData, const uint32_t clampedSubgroupSize )
104
108
{
105
- uint32_t subgroupInvocationID,pseudoSubgroupID;
106
- const type_t * subgroupData = getSubgroupData (subgroupInvocationID,pseudoSubgroupID,workgroupData,localInvocationIndex,subgroupSize,workgroupSize);
107
- type_t retval = subgroupData[0 ];
108
- for (auto i=1u ; i<core::min<uint32_t >(subgroupSize,workgroupSize-pseudoSubgroupID); i++)
109
- retval = OP ()(retval,subgroupData[i]);
110
- return retval;
109
+ type_t red = subgroupData[0 ];
110
+ for (auto i=1u ; i<clampedSubgroupSize; i++)
111
+ red = OP ()(red,subgroupData[i]);
112
+ std::fill (outSubgroupData,outSubgroupData+clampedSubgroupSize,red);
111
113
}
112
114
113
115
_IRR_STATIC_INLINE_CONSTEXPR const char * name = " subgroup reduction" ;
114
116
};
115
117
template <class OP >
116
- struct emulatedSubgroupScanExclusive : emulatedSubgroupCommon<typename OP::type_t >
118
+ struct emulatedSubgroupScanExclusive : emulatedSubgroupCommon<emulatedSubgroupScanExclusive<OP>, typename OP::type_t >
117
119
{
118
120
using type_t = typename OP::type_t ;
119
121
120
- inline type_t operator ()( const type_t * workgroupData , const uint32_t localInvocationIndex, uint32_t subgroupSize, uint32_t workgroupSize )
122
+ static inline void impl ( type_t * outSubgroupData , const type_t * subgroupData, const uint32_t clampedSubgroupSize )
121
123
{
122
- uint32_t subgroupInvocationID,dummy;
123
- const type_t * subgroupData = getSubgroupData (subgroupInvocationID,dummy,workgroupData,localInvocationIndex,subgroupSize,workgroupSize);
124
- type_t retval = OP::IdentityElement;
125
- for (auto i=0u ; i<subgroupInvocationID; i++)
126
- retval = OP ()(retval, subgroupData[i]);
127
- return retval;
124
+ outSubgroupData[0u ] = OP::IdentityElement;
125
+ for (auto i=1u ; i<clampedSubgroupSize; i++)
126
+ outSubgroupData[i] = OP ()(outSubgroupData[i-1u ],subgroupData[i-1u ]);
128
127
}
129
128
130
129
_IRR_STATIC_INLINE_CONSTEXPR const char * name = " subgroup exclusive scan" ;
131
130
};
132
131
template <class OP >
133
- struct emulatedSubgroupScanInclusive : emulatedSubgroupCommon<typename OP::type_t >
132
+ struct emulatedSubgroupScanInclusive : emulatedSubgroupCommon<emulatedSubgroupScanInclusive<OP>, typename OP::type_t >
134
133
{
135
134
using type_t = typename OP::type_t ;
136
135
137
- inline type_t operator ()( const type_t * workgroupData , const uint32_t localInvocationIndex, uint32_t subgroupSize, uint32_t workgroupSize )
136
+ static inline void impl ( type_t * outSubgroupData , const type_t * subgroupData, const uint32_t clampedSubgroupSize )
138
137
{
139
- uint32_t subgroupInvocationID,dummy;
140
- const type_t * subgroupData = getSubgroupData (subgroupInvocationID,dummy,workgroupData,localInvocationIndex,subgroupSize,workgroupSize);
141
- type_t retval = OP::IdentityElement;
142
- for (auto i=0u ; i<=subgroupInvocationID; i++)
143
- retval = OP ()(retval, subgroupData[i]);
144
- return retval;
138
+ outSubgroupData[0u ] = subgroupData[0u ];
139
+ for (auto i=1u ; i<clampedSubgroupSize; i++)
140
+ outSubgroupData[i] = OP ()(outSubgroupData[i-1u ],subgroupData[i]);
145
141
}
146
142
147
143
_IRR_STATIC_INLINE_CONSTEXPR const char * name = " subgroup inclusive scan" ;
@@ -153,12 +149,12 @@ struct emulatedWorkgroupReduction
153
149
{
154
150
using type_t = typename OP::type_t ;
155
151
156
- inline type_t operator ()(const type_t * workgroupData , const uint32_t localInvocationIndex , uint32_t subgroupSize , uint32_t workgroupSize )
152
+ inline void operator ()(type_t * outputData , const type_t * workgroupData , uint32_t workgroupSize , uint32_t subgroupSize )
157
153
{
158
- type_t retval = workgroupData[0 ];
154
+ type_t red = workgroupData[0 ];
159
155
for (auto i=1u ; i<workgroupSize; i++)
160
- retval = OP ()(retval ,workgroupData[i]);
161
- return retval ;
156
+ red = OP ()(red ,workgroupData[i]);
157
+ std::fill (outputData,outputData+workgroupSize,red) ;
162
158
}
163
159
164
160
_IRR_STATIC_INLINE_CONSTEXPR const char * name = " workgroup reduction" ;
@@ -168,12 +164,11 @@ struct emulatedWorkgroupScanExclusive
168
164
{
169
165
using type_t = typename OP::type_t ;
170
166
171
- inline type_t operator ()(const type_t * workgroupData , const uint32_t localInvocationIndex , uint32_t subgroupSize , uint32_t workgroupSize )
167
+ inline void operator ()(type_t * outputData , const type_t * workgroupData , uint32_t workgroupSize , uint32_t subgroupSize )
172
168
{
173
- type_t retval = OP::IdentityElement;
174
- for (auto i=0u ; i<localInvocationIndex; i++)
175
- retval = OP ()(retval,workgroupData[i]);
176
- return retval;
169
+ outputData[0u ] = OP::IdentityElement;
170
+ for (auto i=1u ; i<workgroupSize; i++)
171
+ outputData[i] = OP ()(outputData[i-1u ],workgroupData[i-1u ]);
177
172
}
178
173
179
174
_IRR_STATIC_INLINE_CONSTEXPR const char * name = " workgroup exclusive scan" ;
@@ -183,12 +178,11 @@ struct emulatedWorkgroupScanInclusive
183
178
{
184
179
using type_t = typename OP::type_t ;
185
180
186
- inline type_t operator ()(const type_t * workgroupData , const uint32_t localInvocationIndex , uint32_t subgroupSize , uint32_t workgroupSize )
181
+ inline void operator ()(type_t * outputData , const type_t * workgroupData , uint32_t workgroupSize , uint32_t subgroupSize )
187
182
{
188
- type_t retval = OP::IdentityElement;
189
- for (auto i=0u ; i<=localInvocationIndex; i++)
190
- retval = OP ()(retval,workgroupData[i]);
191
- return retval;
183
+ outputData[0u ] = workgroupData[0u ];
184
+ for (auto i=1u ; i<workgroupSize; i++)
185
+ outputData[i] = OP ()(outputData[i-1u ],workgroupData[i]);
192
186
}
193
187
194
188
_IRR_STATIC_INLINE_CONSTEXPR const char * name = " workgroup inclusive scan" ;
@@ -232,21 +226,21 @@ bool validateResults(video::IVideoDriver* driver, const uint32_t* inputData, con
232
226
auto dataFromBuffer = reinterpret_cast <uint32_t *>(reinterpret_cast <uint8_t *>(downloadStagingArea->getBufferPointer ())+address);
233
227
234
228
// now check if the data obtained has valid values
229
+ constexpr uint32_t subgroupSize = 4u ;
230
+ uint32_t * tmp = new uint32_t [workgroupSize];
235
231
for (uint32_t workgroupID=0u ; success&&workgroupID<workgroupCount; workgroupID++)
236
- for (uint32_t localInvocationIndex=0u ; localInvocationIndex<workgroupSize; localInvocationIndex++)
237
232
{
238
- constexpr uint32_t subgroupSize = 4u ;
239
-
240
233
const auto workgroupOffset = workgroupID*workgroupSize;
241
- uint32_t val = Arithmetic<OP<uint32_t >>()(inputData+workgroupOffset, localInvocationIndex, subgroupSize, workgroupSize );
242
- const auto invocationOffset = workgroupOffset+ localInvocationIndex;
243
- if (val !=dataFromBuffer[invocationOffset ])
234
+ Arithmetic<OP<uint32_t >>()(tmp, inputData+workgroupOffset,workgroupSize, subgroupSize);
235
+ for ( uint32_t localInvocationIndex= 0u ; localInvocationIndex<workgroupSize; localInvocationIndex++)
236
+ if (tmp[localInvocationIndex] !=dataFromBuffer[workgroupOffset+localInvocationIndex ])
244
237
{
245
238
os::Printer::log (" Failed test #" + std::to_string (workgroupSize) + " (" + Arithmetic<OP<uint32_t >>::name + " ) (" + OP<uint32_t >::name + " )" , ELL_ERROR);
246
239
success = false ;
247
240
break ;
248
241
}
249
242
}
243
+ delete[] tmp;
250
244
}
251
245
else
252
246
os::Printer::log (" Could not download the buffer from the GPU, fence not signalled!" , ELL_ERROR);
@@ -384,7 +378,7 @@ int main()
384
378
// max workgroup size is hardcoded to 1024
385
379
uint32_t totalFailCount = 0 ;
386
380
const auto ds = descriptorSet.get ();
387
- for (uint32_t workgroupSize=8u ; workgroupSize<=1024u ; workgroupSize++)
381
+ for (uint32_t workgroupSize=1u ; workgroupSize<=1024u ; workgroupSize++)
388
382
{
389
383
core::smart_refctd_ptr<IGPUComputePipeline> pipelines[kTestTypeCount ];
390
384
for (uint32_t i=0u ; i<kTestTypeCount ; i++)
@@ -398,8 +392,8 @@ int main()
398
392
passed = runTest<emulatedSubgroupScanExclusive>(driver,pipelines[1u ].get (),descriptorSet.get (),inputData,workgroupSize,buffers)&&passed;
399
393
passed = runTest<emulatedSubgroupScanInclusive>(driver,pipelines[2u ].get (),descriptorSet.get (),inputData,workgroupSize,buffers)&&passed;
400
394
passed = runTest<emulatedWorkgroupReduction>(driver,pipelines[3u ].get (),descriptorSet.get (),inputData,workgroupSize,buffers)&&passed;
401
- // passed = runTest<emulatedWorkgroupScanExclusive>(driver,pipelines[4u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
402
- // passed = runTest<emulatedWorkgroupScanInclusive>(driver,pipelines[5u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
395
+ passed = runTest<emulatedWorkgroupScanExclusive>(driver,pipelines[4u ].get (),descriptorSet.get (),inputData,workgroupSize,buffers)&&passed;
396
+ passed = runTest<emulatedWorkgroupScanInclusive>(driver,pipelines[5u ].get (),descriptorSet.get (),inputData,workgroupSize,buffers)&&passed;
403
397
404
398
if (passed)
405
399
os::Printer::log (" Passed test #" + std::to_string (workgroupSize), ELL_INFORMATION);
0 commit comments