|
16 | 16 | #include "llvm/Support/Error.h"
|
17 | 17 | #include "llvm/Support/raw_ostream.h"
|
18 | 18 | #include <cmath>
|
| 19 | +#include <map> |
19 | 20 | #include <sstream>
|
| 21 | +#include <tuple> |
20 | 22 |
|
21 | 23 | constexpr uint16_t Float16BitSign = 0x8000;
|
22 | 24 | constexpr uint16_t Float16BitExp = 0x7c00;
|
@@ -277,6 +279,121 @@ static bool testBufferFloatULP(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
|
277 | 279 | return false;
|
278 | 280 | }
|
279 | 281 |
|
| 282 | +static bool testBufferParticipantPattern(offloadtest::Buffer *B1, offloadtest::Buffer *B2, |
| 283 | + unsigned GroupSize, std::string &errorMsg) { |
| 284 | + // Expect 3 x uint32_t: (combinedId, maskLow, maskHigh) |
| 285 | + if (GroupSize == 0) { |
| 286 | + errorMsg = "Invalid GroupSize (must be > 0)"; |
| 287 | + return false; |
| 288 | + } |
| 289 | + |
| 290 | + // Basic structural checks similar to testBufferExact |
| 291 | + if (B1->ArraySize != B2->ArraySize || B1->size() != B2->size()) { |
| 292 | + errorMsg = "Mismatched buffer shape (ArraySize or per-chunk size differs)"; |
| 293 | + return false; |
| 294 | + } |
| 295 | + |
| 296 | + // We operate on 32-bit words |
| 297 | + if ((B1->size() % sizeof(uint32_t)) != 0) { |
| 298 | + errorMsg = "Chunk size is not a multiple of 4 bytes"; |
| 299 | + return false; |
| 300 | + } |
| 301 | + if ((B2->size() % sizeof(uint32_t)) != 0) { |
| 302 | + errorMsg = "Expected chunk size is not a multiple of 4 bytes"; |
| 303 | + return false; |
| 304 | + } |
| 305 | + |
| 306 | + const uint32_t WordsPerChunk = static_cast<uint32_t>(B1->size() / sizeof(uint32_t)); |
| 307 | + if (WordsPerChunk % GroupSize != 0) { |
| 308 | + errorMsg = "Words per chunk must be a multiple of GroupSize"; |
| 309 | + return false; |
| 310 | + } |
| 311 | + |
| 312 | + using PatternTuple = std::tuple<uint32_t, uint32_t, uint32_t>; |
| 313 | + std::map<PatternTuple, unsigned> actualPatterns; |
| 314 | + std::map<PatternTuple, unsigned> expectedPatterns; |
| 315 | + |
| 316 | + auto read_u32 = [](const char* base, uint32_t wordIndex) -> uint32_t { |
| 317 | + uint32_t v; |
| 318 | + std::memcpy(&v, base + wordIndex * sizeof(uint32_t), sizeof(uint32_t)); |
| 319 | + return v; |
| 320 | + }; |
| 321 | + |
| 322 | + // Accumulate patterns from all chunks |
| 323 | + auto B1It = B1->Data.begin(); |
| 324 | + auto B2It = B2->Data.begin(); |
| 325 | + for (; B1It != B1->Data.end() && B2It != B2->Data.end(); ++B1It, ++B2It) { |
| 326 | + const char* aBuf = B1It->get(); // unique_ptr<char[]> -> char* |
| 327 | + const char* eBuf = B2It->get(); |
| 328 | + |
| 329 | + for (uint32_t i = 0; i + GroupSize <= WordsPerChunk; i += GroupSize) { |
| 330 | + if (GroupSize == 3) { |
| 331 | + // Actual |
| 332 | + PatternTuple ap(read_u32(aBuf, i + 0), |
| 333 | + read_u32(aBuf, i + 1), |
| 334 | + read_u32(aBuf, i + 2)); |
| 335 | + ++actualPatterns[ap]; |
| 336 | + |
| 337 | + // Expected |
| 338 | + PatternTuple ep(read_u32(eBuf, i + 0), |
| 339 | + read_u32(eBuf, i + 1), |
| 340 | + read_u32(eBuf, i + 2)); |
| 341 | + ++expectedPatterns[ep]; |
| 342 | + } else { |
| 343 | + // If you plan to support other group sizes later, handle here. |
| 344 | + } |
| 345 | + } |
| 346 | + } |
| 347 | + |
| 348 | + // Compare pattern multisets |
| 349 | + std::stringstream ss; |
| 350 | + bool hasError = false; |
| 351 | + |
| 352 | + if (actualPatterns.size() != expectedPatterns.size()) { |
| 353 | + ss << "Pattern kind count mismatch: actual has " << actualPatterns.size() |
| 354 | + << " unique patterns, expected has " << expectedPatterns.size() << " unique patterns\n"; |
| 355 | + hasError = true; |
| 356 | + } |
| 357 | + |
| 358 | + // Missing / count-mismatched patterns |
| 359 | + for (const auto& [pattern, expCount] : expectedPatterns) { |
| 360 | + auto it = actualPatterns.find(pattern); |
| 361 | + if (it == actualPatterns.end()) { |
| 362 | + if (!hasError) ss << "Pattern differences found:\n"; |
| 363 | + hasError = true; |
| 364 | + ss << " Missing pattern (combineId=" << std::get<0>(pattern) |
| 365 | + << ", maskLow=0x" << std::hex << std::get<1>(pattern) |
| 366 | + << ", maskHigh=0x" << std::get<2>(pattern) << std::dec |
| 367 | + << ") - expected count: " << expCount << ", actual count: 0\n"; |
| 368 | + } else if (it->second != expCount) { |
| 369 | + if (!hasError) ss << "Pattern differences found:\n"; |
| 370 | + hasError = true; |
| 371 | + ss << " Pattern (combineId=" << std::get<0>(pattern) |
| 372 | + << ", maskLow=0x" << std::hex << std::get<1>(pattern) |
| 373 | + << ", maskHigh=0x" << std::get<2>(pattern) << std::dec |
| 374 | + << ") - expected count: " << expCount << ", actual count: " << it->second << "\n"; |
| 375 | + } |
| 376 | + } |
| 377 | + |
| 378 | + // Unexpected patterns |
| 379 | + for (const auto& [pattern, actCount] : actualPatterns) { |
| 380 | + if (expectedPatterns.find(pattern) == expectedPatterns.end()) { |
| 381 | + if (!hasError) ss << "Pattern differences found:\n"; |
| 382 | + hasError = true; |
| 383 | + ss << " Unexpected pattern (combineId=" << std::get<0>(pattern) |
| 384 | + << ", maskLow=0x" << std::hex << std::get<1>(pattern) |
| 385 | + << ", maskHigh=0x" << std::get<2>(pattern) << std::dec |
| 386 | + << ") - expected count: 0, actual count: " << actCount << "\n"; |
| 387 | + } |
| 388 | + } |
| 389 | + |
| 390 | + if (hasError) { |
| 391 | + errorMsg = ss.str(); |
| 392 | + return false; |
| 393 | + } |
| 394 | + return true; |
| 395 | +} |
| 396 | + |
280 | 397 | template <typename T>
|
281 | 398 | static std::string bitPatternAsHex64(const T &Val,
|
282 | 399 | offloadtest::Rule ComparisonRule) {
|
@@ -391,6 +508,16 @@ llvm::Error verifyResult(offloadtest::Result R) {
|
391 | 508 | case offloadtest::Rule::BufferFloatEpsilon: {
|
392 | 509 | if (testBufferFloatEpsilon(R.ActualPtr, R.ExpectedPtr, R.Epsilon, R.DM))
|
393 | 510 | return llvm::Error::success();
|
| 511 | + break; |
| 512 | + } |
| 513 | + case offloadtest::Rule::BufferParticipantPattern: { |
| 514 | + std::string errorMsg; |
| 515 | + if (testBufferParticipantPattern(R.ActualPtr, R.ExpectedPtr, R.GroupSize, errorMsg)) |
| 516 | + return llvm::Error::success(); |
| 517 | + // Return error with detailed message |
| 518 | + return llvm::make_error<llvm::StringError>( |
| 519 | + "BufferParticipantPattern test failed for " + R.Name + ":\n" + errorMsg, |
| 520 | + std::error_code()); |
394 | 521 |
|
395 | 522 | std::ostringstream Oss;
|
396 | 523 | Oss << std::defaultfloat << R.Epsilon;
|
|
0 commit comments