15
15
#include " llvm/Support/Error.h"
16
16
#include " llvm/Support/raw_ostream.h"
17
17
#include < cmath>
18
+ #include < map>
19
+ #include < tuple>
18
20
19
21
constexpr uint16_t Float16BitSign = 0x8000 ;
20
22
constexpr uint16_t Float16BitExp = 0x7c00 ;
@@ -267,6 +269,53 @@ static bool testBufferFloatULP(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
267
269
return false ;
268
270
}
269
271
272
+ static bool testBufferParticipantPattern (offloadtest::Buffer *B1, offloadtest::Buffer *B2,
273
+ unsigned GroupSize) {
274
+ // B1 is actual, B2 is expected
275
+ // GroupSize should be 3 for participant patterns (combinedId, maskLow, maskHigh)
276
+ if (GroupSize == 0 || GroupSize > B1->size () || GroupSize > B2->size ())
277
+ return false ;
278
+
279
+ // Ensure buffer sizes are multiples of GroupSize
280
+ if (B1->size () % GroupSize != 0 || B2->size () % GroupSize != 0 )
281
+ return false ;
282
+
283
+ // Parse patterns from both buffers
284
+ using PatternTuple = std::tuple<uint32_t , uint32_t , uint32_t >;
285
+ std::map<PatternTuple, unsigned > actualPatterns;
286
+ std::map<PatternTuple, unsigned > expectedPatterns;
287
+
288
+ // Count patterns in actual buffer
289
+ const uint32_t * actualData = reinterpret_cast <const uint32_t *>(B1->Data .get ());
290
+ for (size_t i = 0 ; i < B1->size () / sizeof (uint32_t ); i += GroupSize) {
291
+ if (GroupSize == 3 ) {
292
+ PatternTuple pattern (actualData[i], actualData[i+1 ], actualData[i+2 ]);
293
+ actualPatterns[pattern]++;
294
+ }
295
+ }
296
+
297
+ // Count patterns in expected buffer
298
+ const uint32_t * expectedData = reinterpret_cast <const uint32_t *>(B2->Data .get ());
299
+ for (size_t i = 0 ; i < B2->size () / sizeof (uint32_t ); i += GroupSize) {
300
+ if (GroupSize == 3 ) {
301
+ PatternTuple pattern (expectedData[i], expectedData[i+1 ], expectedData[i+2 ]);
302
+ expectedPatterns[pattern]++;
303
+ }
304
+ }
305
+
306
+ // Compare pattern counts
307
+ if (actualPatterns.size () != expectedPatterns.size ())
308
+ return false ;
309
+
310
+ for (const auto & [pattern, count] : expectedPatterns) {
311
+ auto it = actualPatterns.find (pattern);
312
+ if (it == actualPatterns.end () || it->second != count)
313
+ return false ;
314
+ }
315
+
316
+ return true ;
317
+ }
318
+
270
319
llvm::Error verifyResult (offloadtest::Result R) {
271
320
switch (R.Rule ) {
272
321
case offloadtest::Rule::BufferExact: {
@@ -284,6 +333,11 @@ llvm::Error verifyResult(offloadtest::Result R) {
284
333
return llvm::Error::success ();
285
334
break ;
286
335
}
336
+ case offloadtest::Rule::BufferParticipantPattern: {
337
+ if (testBufferParticipantPattern (R.ActualPtr , R.ExpectedPtr , R.GroupSize ))
338
+ return llvm::Error::success ();
339
+ break ;
340
+ }
287
341
}
288
342
llvm::SmallString<256 > Str;
289
343
llvm::raw_svector_ostream OS (Str);
0 commit comments