Skip to content

Commit 15fd7f1

Browse files
add new buffer match rule to work with bittracking
1 parent 5136e82 commit 15fd7f1

File tree

4 files changed

+107
-1
lines changed

4 files changed

+107
-1
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Example test case using BufferParticipantPattern rule for WaveParticipantBitTracking
2+
---
3+
Shaders:
4+
- Stage: Compute
5+
Entry: main
6+
DispatchSize: [8, 1, 1] # 8 threads
7+
Buffers:
8+
# Actual output from the shader (may have patterns in any order)
9+
- Name: ParticipantOutput
10+
Format: UInt32
11+
ZeroInitSize: 96 # Space for multiple patterns (3 uint32 per pattern)
12+
13+
# Expected patterns - order doesn't matter, but pattern counts must match
14+
- Name: ExpectedPatterns
15+
Format: UInt32
16+
Data: [
17+
# Pattern 1: Wave op ID 69, loop iteration 0, participants 0,1,2,3
18+
4416, 0x000F, 0x0000, # (69<<6)|0, mask for threads 0-3, high mask
19+
4416, 0x000F, 0x0000, # Duplicate 1
20+
4416, 0x000F, 0x0000, # Duplicate 2
21+
4416, 0x000F, 0x0000, # Duplicate 3 (4 participants = 4 copies)
22+
23+
# Pattern 2: Wave op ID 70, loop iteration 1, participants 4,5,6,7
24+
4496, 0x00F0, 0x0000, # (70<<6)|(1<<4), mask for threads 4-7, high mask
25+
4496, 0x00F0, 0x0000, # Duplicate 1
26+
4496, 0x00F0, 0x0000, # Duplicate 2
27+
4496, 0x00F0, 0x0000, # Duplicate 3 (4 participants = 4 copies)
28+
]
29+
30+
Results:
31+
- Result: ValidateParticipantPatterns
32+
Rule: BufferParticipantPattern
33+
GroupSize: 3 # Each pattern consists of 3 uint32 values
34+
Actual: ParticipantOutput
35+
Expected: ExpectedPatterns
36+
37+
DescriptorSets:
38+
- Resources:
39+
- Name: ParticipantOutput
40+
Kind: RWBuffer
41+
DirectXBinding:
42+
Register: 0
43+
Space: 0
44+
VulkanBinding:
45+
Binding: 0
46+
...

include/Support/Pipeline.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace offloadtest {
2525

2626
enum class Stages { Compute };
2727

28-
enum class Rule { BufferExact, BufferFloatULP, BufferFloatEpsilon };
28+
enum class Rule { BufferExact, BufferFloatULP, BufferFloatEpsilon, BufferParticipantPattern };
2929

3030
enum class DenormMode { Any, FTZ, Preserve };
3131

@@ -124,6 +124,7 @@ struct Result {
124124
DenormMode DM = DenormMode::Any;
125125
unsigned ULPT; // ULP Tolerance
126126
double Epsilon;
127+
unsigned GroupSize = 0; // For BufferParticipantPattern rule
127128
};
128129

129130
struct Resource {
@@ -319,6 +320,7 @@ template <> struct ScalarEnumerationTraits<offloadtest::Rule> {
319320
ENUM_CASE(BufferExact);
320321
ENUM_CASE(BufferFloatULP);
321322
ENUM_CASE(BufferFloatEpsilon);
323+
ENUM_CASE(BufferParticipantPattern);
322324
#undef ENUM_CASE
323325
}
324326
};

lib/Support/Check.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "llvm/Support/Error.h"
1616
#include "llvm/Support/raw_ostream.h"
1717
#include <cmath>
18+
#include <map>
19+
#include <tuple>
1820

1921
constexpr uint16_t Float16BitSign = 0x8000;
2022
constexpr uint16_t Float16BitExp = 0x7c00;
@@ -267,6 +269,53 @@ static bool testBufferFloatULP(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
267269
return false;
268270
}
269271

272+
static bool testBufferParticipantPattern(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
273+
unsigned GroupSize) {
274+
// B1 is actual, B2 is expected
275+
// GroupSize should be 3 for participant patterns (combinedId, maskLow, maskHigh)
276+
if (GroupSize == 0 || GroupSize > B1->size() || GroupSize > B2->size())
277+
return false;
278+
279+
// Ensure buffer sizes are multiples of GroupSize
280+
if (B1->size() % GroupSize != 0 || B2->size() % GroupSize != 0)
281+
return false;
282+
283+
// Parse patterns from both buffers
284+
using PatternTuple = std::tuple<uint32_t, uint32_t, uint32_t>;
285+
std::map<PatternTuple, unsigned> actualPatterns;
286+
std::map<PatternTuple, unsigned> expectedPatterns;
287+
288+
// Count patterns in actual buffer
289+
const uint32_t* actualData = reinterpret_cast<const uint32_t*>(B1->Data.get());
290+
for (size_t i = 0; i < B1->size() / sizeof(uint32_t); i += GroupSize) {
291+
if (GroupSize == 3) {
292+
PatternTuple pattern(actualData[i], actualData[i+1], actualData[i+2]);
293+
actualPatterns[pattern]++;
294+
}
295+
}
296+
297+
// Count patterns in expected buffer
298+
const uint32_t* expectedData = reinterpret_cast<const uint32_t*>(B2->Data.get());
299+
for (size_t i = 0; i < B2->size() / sizeof(uint32_t); i += GroupSize) {
300+
if (GroupSize == 3) {
301+
PatternTuple pattern(expectedData[i], expectedData[i+1], expectedData[i+2]);
302+
expectedPatterns[pattern]++;
303+
}
304+
}
305+
306+
// Compare pattern counts
307+
if (actualPatterns.size() != expectedPatterns.size())
308+
return false;
309+
310+
for (const auto& [pattern, count] : expectedPatterns) {
311+
auto it = actualPatterns.find(pattern);
312+
if (it == actualPatterns.end() || it->second != count)
313+
return false;
314+
}
315+
316+
return true;
317+
}
318+
270319
llvm::Error verifyResult(offloadtest::Result R) {
271320
switch (R.Rule) {
272321
case offloadtest::Rule::BufferExact: {
@@ -284,6 +333,11 @@ llvm::Error verifyResult(offloadtest::Result R) {
284333
return llvm::Error::success();
285334
break;
286335
}
336+
case offloadtest::Rule::BufferParticipantPattern: {
337+
if (testBufferParticipantPattern(R.ActualPtr, R.ExpectedPtr, R.GroupSize))
338+
return llvm::Error::success();
339+
break;
340+
}
287341
}
288342
llvm::SmallString<256> Str;
289343
llvm::raw_svector_ostream OS(Str);

lib/Support/Pipeline.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,10 @@ void MappingTraits<offloadtest::Result>::mapping(IO &I,
255255
I.mapOptional("DenormMode", R.DM);
256256
break;
257257
}
258+
case Rule::BufferParticipantPattern: {
259+
I.mapRequired("GroupSize", R.GroupSize);
260+
break;
261+
}
258262
default:
259263
break;
260264
}

0 commit comments

Comments
 (0)