Skip to content

Commit 3295c7f

Browse files
update error message for testBufferParticipantPattern
1 parent d1c48bf commit 3295c7f

File tree

1 file changed

+127
-0
lines changed

1 file changed

+127
-0
lines changed

lib/Support/Check.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
#include "llvm/Support/Error.h"
1717
#include "llvm/Support/raw_ostream.h"
1818
#include <cmath>
19+
#include <map>
1920
#include <sstream>
21+
#include <tuple>
2022

2123
constexpr uint16_t Float16BitSign = 0x8000;
2224
constexpr uint16_t Float16BitExp = 0x7c00;
@@ -277,6 +279,121 @@ static bool testBufferFloatULP(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
277279
return false;
278280
}
279281

282+
static bool testBufferParticipantPattern(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
283+
unsigned GroupSize, std::string &errorMsg) {
284+
// Expect 3 x uint32_t: (combinedId, maskLow, maskHigh)
285+
if (GroupSize == 0) {
286+
errorMsg = "Invalid GroupSize (must be > 0)";
287+
return false;
288+
}
289+
290+
// Basic structural checks similar to testBufferExact
291+
if (B1->ArraySize != B2->ArraySize || B1->size() != B2->size()) {
292+
errorMsg = "Mismatched buffer shape (ArraySize or per-chunk size differs)";
293+
return false;
294+
}
295+
296+
// We operate on 32-bit words
297+
if ((B1->size() % sizeof(uint32_t)) != 0) {
298+
errorMsg = "Chunk size is not a multiple of 4 bytes";
299+
return false;
300+
}
301+
if ((B2->size() % sizeof(uint32_t)) != 0) {
302+
errorMsg = "Expected chunk size is not a multiple of 4 bytes";
303+
return false;
304+
}
305+
306+
const uint32_t WordsPerChunk = static_cast<uint32_t>(B1->size() / sizeof(uint32_t));
307+
if (WordsPerChunk % GroupSize != 0) {
308+
errorMsg = "Words per chunk must be a multiple of GroupSize";
309+
return false;
310+
}
311+
312+
using PatternTuple = std::tuple<uint32_t, uint32_t, uint32_t>;
313+
std::map<PatternTuple, unsigned> actualPatterns;
314+
std::map<PatternTuple, unsigned> expectedPatterns;
315+
316+
auto read_u32 = [](const char* base, uint32_t wordIndex) -> uint32_t {
317+
uint32_t v;
318+
std::memcpy(&v, base + wordIndex * sizeof(uint32_t), sizeof(uint32_t));
319+
return v;
320+
};
321+
322+
// Accumulate patterns from all chunks
323+
auto B1It = B1->Data.begin();
324+
auto B2It = B2->Data.begin();
325+
for (; B1It != B1->Data.end() && B2It != B2->Data.end(); ++B1It, ++B2It) {
326+
const char* aBuf = B1It->get(); // unique_ptr<char[]> -> char*
327+
const char* eBuf = B2It->get();
328+
329+
for (uint32_t i = 0; i + GroupSize <= WordsPerChunk; i += GroupSize) {
330+
if (GroupSize == 3) {
331+
// Actual
332+
PatternTuple ap(read_u32(aBuf, i + 0),
333+
read_u32(aBuf, i + 1),
334+
read_u32(aBuf, i + 2));
335+
++actualPatterns[ap];
336+
337+
// Expected
338+
PatternTuple ep(read_u32(eBuf, i + 0),
339+
read_u32(eBuf, i + 1),
340+
read_u32(eBuf, i + 2));
341+
++expectedPatterns[ep];
342+
} else {
343+
// If you plan to support other group sizes later, handle here.
344+
}
345+
}
346+
}
347+
348+
// Compare pattern multisets
349+
std::stringstream ss;
350+
bool hasError = false;
351+
352+
if (actualPatterns.size() != expectedPatterns.size()) {
353+
ss << "Pattern kind count mismatch: actual has " << actualPatterns.size()
354+
<< " unique patterns, expected has " << expectedPatterns.size() << " unique patterns\n";
355+
hasError = true;
356+
}
357+
358+
// Missing / count-mismatched patterns
359+
for (const auto& [pattern, expCount] : expectedPatterns) {
360+
auto it = actualPatterns.find(pattern);
361+
if (it == actualPatterns.end()) {
362+
if (!hasError) ss << "Pattern differences found:\n";
363+
hasError = true;
364+
ss << " Missing pattern (combineId=" << std::get<0>(pattern)
365+
<< ", maskLow=0x" << std::hex << std::get<1>(pattern)
366+
<< ", maskHigh=0x" << std::get<2>(pattern) << std::dec
367+
<< ") - expected count: " << expCount << ", actual count: 0\n";
368+
} else if (it->second != expCount) {
369+
if (!hasError) ss << "Pattern differences found:\n";
370+
hasError = true;
371+
ss << " Pattern (combineId=" << std::get<0>(pattern)
372+
<< ", maskLow=0x" << std::hex << std::get<1>(pattern)
373+
<< ", maskHigh=0x" << std::get<2>(pattern) << std::dec
374+
<< ") - expected count: " << expCount << ", actual count: " << it->second << "\n";
375+
}
376+
}
377+
378+
// Unexpected patterns
379+
for (const auto& [pattern, actCount] : actualPatterns) {
380+
if (expectedPatterns.find(pattern) == expectedPatterns.end()) {
381+
if (!hasError) ss << "Pattern differences found:\n";
382+
hasError = true;
383+
ss << " Unexpected pattern (combineId=" << std::get<0>(pattern)
384+
<< ", maskLow=0x" << std::hex << std::get<1>(pattern)
385+
<< ", maskHigh=0x" << std::get<2>(pattern) << std::dec
386+
<< ") - expected count: 0, actual count: " << actCount << "\n";
387+
}
388+
}
389+
390+
if (hasError) {
391+
errorMsg = ss.str();
392+
return false;
393+
}
394+
return true;
395+
}
396+
280397
template <typename T>
281398
static std::string bitPatternAsHex64(const T &Val,
282399
offloadtest::Rule ComparisonRule) {
@@ -391,6 +508,16 @@ llvm::Error verifyResult(offloadtest::Result R) {
391508
case offloadtest::Rule::BufferFloatEpsilon: {
392509
if (testBufferFloatEpsilon(R.ActualPtr, R.ExpectedPtr, R.Epsilon, R.DM))
393510
return llvm::Error::success();
511+
break;
512+
}
513+
case offloadtest::Rule::BufferParticipantPattern: {
514+
std::string errorMsg;
515+
if (testBufferParticipantPattern(R.ActualPtr, R.ExpectedPtr, R.GroupSize, errorMsg))
516+
return llvm::Error::success();
517+
// Return error with detailed message
518+
return llvm::make_error<llvm::StringError>(
519+
"BufferParticipantPattern test failed for " + R.Name + ":\n" + errorMsg,
520+
std::error_code());
394521

395522
std::ostringstream Oss;
396523
Oss << std::defaultfloat << R.Epsilon;

0 commit comments

Comments
 (0)