Skip to content

Commit df5d12f

Browse files
update error message for testBufferParticipantPattern
1 parent ef4a1c5 commit df5d12f

File tree

1 file changed

+57
-11
lines changed

1 file changed

+57
-11
lines changed

lib/Support/Check.cpp

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "llvm/Support/raw_ostream.h"
1717
#include <cmath>
1818
#include <map>
19+
#include <sstream>
1920
#include <tuple>
2021

2122
constexpr uint16_t Float16BitSign = 0x8000;
@@ -270,15 +271,19 @@ static bool testBufferFloatULP(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
270271
}
271272

272273
static bool testBufferParticipantPattern(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
273-
unsigned GroupSize) {
274+
unsigned GroupSize, std::string& errorMsg) {
274275
// B1 is actual, B2 is expected
275276
// GroupSize should be 3 for participant patterns (combinedId, maskLow, maskHigh)
276-
if (GroupSize == 0 || GroupSize > B1->size() || GroupSize > B2->size())
277+
if (GroupSize == 0 || GroupSize > B1->size() || GroupSize > B2->size()) {
278+
errorMsg = "Invalid GroupSize or buffer too small";
277279
return false;
280+
}
278281

279282
// Ensure buffer sizes are multiples of GroupSize
280-
if (B1->size() % GroupSize != 0 || B2->size() % GroupSize != 0)
283+
if (B1->size() % GroupSize != 0 || B2->size() % GroupSize != 0) {
284+
errorMsg = "Buffer sizes must be multiples of GroupSize";
281285
return false;
286+
}
282287

283288
// Parse patterns from both buffers
284289
using PatternTuple = std::tuple<uint32_t, uint32_t, uint32_t>;
@@ -303,14 +308,51 @@ static bool testBufferParticipantPattern(offloadtest::Buffer *B1, offloadtest::B
303308
}
304309
}
305310

306-
// Compare pattern counts
307-
if (actualPatterns.size() != expectedPatterns.size())
308-
return false;
309-
311+
// Compare pattern counts and collect differences
312+
std::stringstream ss;
313+
bool hasError = false;
314+
315+
if (actualPatterns.size() != expectedPatterns.size()) {
316+
ss << "Pattern count mismatch: actual has " << actualPatterns.size()
317+
<< " unique patterns, expected has " << expectedPatterns.size() << " unique patterns\n";
318+
hasError = true;
319+
}
320+
321+
// Check for missing patterns
310322
for (const auto& [pattern, count] : expectedPatterns) {
311323
auto it = actualPatterns.find(pattern);
312-
if (it == actualPatterns.end() || it->second != count)
313-
return false;
324+
if (it == actualPatterns.end()) {
325+
if (!hasError) ss << "Pattern differences found:\n";
326+
hasError = true;
327+
ss << " Missing pattern (combineId=" << std::get<0>(pattern)
328+
<< ", maskLow=0x" << std::hex << std::get<1>(pattern)
329+
<< ", maskHigh=0x" << std::get<2>(pattern) << std::dec
330+
<< ") - expected count: " << count << ", actual count: 0\n";
331+
} else if (it->second != count) {
332+
if (!hasError) ss << "Pattern differences found:\n";
333+
hasError = true;
334+
ss << " Pattern (combineId=" << std::get<0>(pattern)
335+
<< ", maskLow=0x" << std::hex << std::get<1>(pattern)
336+
<< ", maskHigh=0x" << std::get<2>(pattern) << std::dec
337+
<< ") - expected count: " << count << ", actual count: " << it->second << "\n";
338+
}
339+
}
340+
341+
// Check for unexpected patterns
342+
for (const auto& [pattern, count] : actualPatterns) {
343+
if (expectedPatterns.find(pattern) == expectedPatterns.end()) {
344+
if (!hasError) ss << "Pattern differences found:\n";
345+
hasError = true;
346+
ss << " Unexpected pattern (combineId=" << std::get<0>(pattern)
347+
<< ", maskLow=0x" << std::hex << std::get<1>(pattern)
348+
<< ", maskHigh=0x" << std::get<2>(pattern) << std::dec
349+
<< ") - expected count: 0, actual count: " << count << "\n";
350+
}
351+
}
352+
353+
if (hasError) {
354+
errorMsg = ss.str();
355+
return false;
314356
}
315357

316358
return true;
@@ -334,9 +376,13 @@ llvm::Error verifyResult(offloadtest::Result R) {
334376
break;
335377
}
336378
case offloadtest::Rule::BufferParticipantPattern: {
337-
if (testBufferParticipantPattern(R.ActualPtr, R.ExpectedPtr, R.GroupSize))
379+
std::string errorMsg;
380+
if (testBufferParticipantPattern(R.ActualPtr, R.ExpectedPtr, R.GroupSize, errorMsg))
338381
return llvm::Error::success();
339-
break;
382+
// Return error with detailed message
383+
return llvm::make_error<llvm::StringError>(
384+
"BufferParticipantPattern test failed for " + R.Name + ":\n" + errorMsg,
385+
std::error_code());
340386
}
341387
}
342388
llvm::SmallString<256> Str;

0 commit comments

Comments
 (0)