16
16
#include " llvm/Support/raw_ostream.h"
17
17
#include < cmath>
18
18
#include < map>
19
+ #include < sstream>
19
20
#include < tuple>
20
21
21
22
constexpr uint16_t Float16BitSign = 0x8000 ;
@@ -270,15 +271,19 @@ static bool testBufferFloatULP(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
270
271
}
271
272
272
273
static bool testBufferParticipantPattern (offloadtest::Buffer *B1, offloadtest::Buffer *B2,
273
- unsigned GroupSize) {
274
+ unsigned GroupSize, std::string& errorMsg ) {
274
275
// B1 is actual, B2 is expected
275
276
// GroupSize should be 3 for participant patterns (combinedId, maskLow, maskHigh)
276
- if (GroupSize == 0 || GroupSize > B1->size () || GroupSize > B2->size ())
277
+ if (GroupSize == 0 || GroupSize > B1->size () || GroupSize > B2->size ()) {
278
+ errorMsg = " Invalid GroupSize or buffer too small" ;
277
279
return false ;
280
+ }
278
281
279
282
// Ensure buffer sizes are multiples of GroupSize
280
- if (B1->size () % GroupSize != 0 || B2->size () % GroupSize != 0 )
283
+ if (B1->size () % GroupSize != 0 || B2->size () % GroupSize != 0 ) {
284
+ errorMsg = " Buffer sizes must be multiples of GroupSize" ;
281
285
return false ;
286
+ }
282
287
283
288
// Parse patterns from both buffers
284
289
using PatternTuple = std::tuple<uint32_t , uint32_t , uint32_t >;
@@ -303,14 +308,51 @@ static bool testBufferParticipantPattern(offloadtest::Buffer *B1, offloadtest::B
303
308
}
304
309
}
305
310
306
- // Compare pattern counts
307
- if (actualPatterns.size () != expectedPatterns.size ())
308
- return false ;
309
-
311
+ // Compare pattern counts and collect differences
312
+ std::stringstream ss;
313
+ bool hasError = false ;
314
+
315
+ if (actualPatterns.size () != expectedPatterns.size ()) {
316
+ ss << " Pattern count mismatch: actual has " << actualPatterns.size ()
317
+ << " unique patterns, expected has " << expectedPatterns.size () << " unique patterns\n " ;
318
+ hasError = true ;
319
+ }
320
+
321
+ // Check for missing patterns
310
322
for (const auto & [pattern, count] : expectedPatterns) {
311
323
auto it = actualPatterns.find (pattern);
312
- if (it == actualPatterns.end () || it->second != count)
313
- return false ;
324
+ if (it == actualPatterns.end ()) {
325
+ if (!hasError) ss << " Pattern differences found:\n " ;
326
+ hasError = true ;
327
+ ss << " Missing pattern (combineId=" << std::get<0 >(pattern)
328
+ << " , maskLow=0x" << std::hex << std::get<1 >(pattern)
329
+ << " , maskHigh=0x" << std::get<2 >(pattern) << std::dec
330
+ << " ) - expected count: " << count << " , actual count: 0\n " ;
331
+ } else if (it->second != count) {
332
+ if (!hasError) ss << " Pattern differences found:\n " ;
333
+ hasError = true ;
334
+ ss << " Pattern (combineId=" << std::get<0 >(pattern)
335
+ << " , maskLow=0x" << std::hex << std::get<1 >(pattern)
336
+ << " , maskHigh=0x" << std::get<2 >(pattern) << std::dec
337
+ << " ) - expected count: " << count << " , actual count: " << it->second << " \n " ;
338
+ }
339
+ }
340
+
341
+ // Check for unexpected patterns
342
+ for (const auto & [pattern, count] : actualPatterns) {
343
+ if (expectedPatterns.find (pattern) == expectedPatterns.end ()) {
344
+ if (!hasError) ss << " Pattern differences found:\n " ;
345
+ hasError = true ;
346
+ ss << " Unexpected pattern (combineId=" << std::get<0 >(pattern)
347
+ << " , maskLow=0x" << std::hex << std::get<1 >(pattern)
348
+ << " , maskHigh=0x" << std::get<2 >(pattern) << std::dec
349
+ << " ) - expected count: 0, actual count: " << count << " \n " ;
350
+ }
351
+ }
352
+
353
+ if (hasError) {
354
+ errorMsg = ss.str ();
355
+ return false ;
314
356
}
315
357
316
358
return true ;
@@ -334,9 +376,13 @@ llvm::Error verifyResult(offloadtest::Result R) {
334
376
break ;
335
377
}
336
378
case offloadtest::Rule::BufferParticipantPattern: {
337
- if (testBufferParticipantPattern (R.ActualPtr , R.ExpectedPtr , R.GroupSize ))
379
+ std::string errorMsg;
380
+ if (testBufferParticipantPattern (R.ActualPtr , R.ExpectedPtr , R.GroupSize , errorMsg))
338
381
return llvm::Error::success ();
339
- break ;
382
+ // Return error with detailed message
383
+ return llvm::make_error<llvm::StringError>(
384
+ " BufferParticipantPattern test failed for " + R.Name + " :\n " + errorMsg,
385
+ std::error_code ());
340
386
}
341
387
}
342
388
llvm::SmallString<256 > Str;
0 commit comments