Skip to content

Commit 4bfffe4

Browse files
authored
Add support for UAV Counters (#80)
* save work * Add support for UAV Counters * Address Comments * Address comments * format * hack: Skip metal in CI
1 parent 90196a8 commit 4bfffe4

File tree

5 files changed

+135
-11
lines changed

5 files changed

+135
-11
lines changed

include/Support/Pipeline.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ struct Buffer {
7979
std::unique_ptr<char[]> Data;
8080
size_t Size;
8181
OutputProperties OutputProps;
82+
uint32_t Counter;
8283

8384
uint32_t size() const { return Size; }
8485

@@ -130,6 +131,7 @@ struct Resource {
130131
DirectXBinding DXBinding;
131132
std::optional<VulkanBinding> VKBinding;
132133
Buffer *BufferPtr = nullptr;
134+
bool HasCounter;
133135

134136
bool isRaw() const {
135137
switch (Kind) {

lib/API/DX/Device.cpp

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,19 @@ static DXGI_FORMAT getRawDXFormat(Resource &R) {
8989
return DXGI_FORMAT_UNKNOWN;
9090
}
9191

92+
/// Returns the number of bytes to allocate for the buffer backing \p R.
///
/// Without a counter this is just the resource's data size. With a counter,
/// the data size is rounded up to D3D12_UAV_COUNTER_PLACEMENT_ALIGNMENT and a
/// 32-bit counter slot is appended after it, so the data and the counter share
/// a single allocation.
static uint32_t getUAVBufferSize(Resource &R) {
  if (!R.HasCounter)
    return R.size();
  // The counter sits directly after the aligned end of the data.
  return llvm::alignTo(R.size(), D3D12_UAV_COUNTER_PLACEMENT_ALIGNMENT) +
         sizeof(uint32_t);
}
98+
99+
/// Returns the byte offset of the 32-bit counter inside the buffer backing
/// \p R, or 0 when the resource has no counter.
///
/// The offset is the resource's data size rounded up to
/// D3D12_UAV_COUNTER_PLACEMENT_ALIGNMENT.
static uint32_t getUAVBufferCounterOffset(Resource &R) {
  if (!R.HasCounter)
    return 0;
  return llvm::alignTo(R.size(), D3D12_UAV_COUNTER_PLACEMENT_ALIGNMENT);
}
104+
92105
namespace {
93106

94107
enum DXResourceKind { UAV, SRV, CBV };
@@ -449,9 +462,10 @@ class DXDevice : public offloadtest::Device {
449462
}
450463

451464
llvm::Expected<ResourceSet> createUAV(Resource &R, InvocationState &IS) {
452-
llvm::outs() << "Creating UAV: { Size = " << R.size() << ", Register = u"
465+
const uint32_t BufferSize = getUAVBufferSize(R);
466+
llvm::outs() << "Creating UAV: { Size = " << BufferSize << ", Register = u"
453467
<< R.DXBinding.Register << ", Space = " << R.DXBinding.Space
454-
<< " }\n";
468+
<< ", HasCounter = " << R.HasCounter << " }\n";
455469
ComPtr<ID3D12Resource> Buffer;
456470
ComPtr<ID3D12Resource> UploadBuffer;
457471
ComPtr<ID3D12Resource> ReadBackBuffer;
@@ -461,7 +475,7 @@ class DXDevice : public offloadtest::Device {
461475
const D3D12_RESOURCE_DESC ResDesc = {
462476
D3D12_RESOURCE_DIMENSION_BUFFER,
463477
0,
464-
R.size(),
478+
BufferSize,
465479
1,
466480
1,
467481
1,
@@ -480,7 +494,7 @@ class DXDevice : public offloadtest::Device {
480494
const D3D12_HEAP_PROPERTIES UploadHeapProp =
481495
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD);
482496
const D3D12_RESOURCE_DESC UploadResDesc =
483-
CD3DX12_RESOURCE_DESC::Buffer(R.size());
497+
CD3DX12_RESOURCE_DESC::Buffer(BufferSize);
484498

485499
if (auto Err =
486500
HR::toError(Device->CreateCommittedResource(
@@ -495,7 +509,7 @@ class DXDevice : public offloadtest::Device {
495509
const D3D12_RESOURCE_DESC ReadBackResDesc = {
496510
D3D12_RESOURCE_DIMENSION_BUFFER,
497511
0,
498-
R.size(),
512+
BufferSize,
499513
1,
500514
1,
501515
1,
@@ -529,24 +543,27 @@ class DXDevice : public offloadtest::Device {
529543
ComPtr<ID3D12Resource> Buffer) {
530544
const uint32_t EltSize = R.getElementSize();
531545
const uint32_t NumElts = R.size() / EltSize;
546+
ID3D12Resource *CounterBuffer = R.HasCounter ? Buffer.Get() : nullptr;
547+
const uint32_t CounterOffset = getUAVBufferCounterOffset(R);
532548
DXGI_FORMAT const EltFormat =
533549
R.isRaw() ? getRawDXFormat(R)
534550
: getDXFormat(R.BufferPtr->Format, R.BufferPtr->Channels);
535551
const D3D12_UNORDERED_ACCESS_VIEW_DESC UAVDesc = {
536552
EltFormat,
537553
D3D12_UAV_DIMENSION_BUFFER,
538-
{D3D12_BUFFER_UAV{0, NumElts, R.isStructuredBuffer() ? EltSize : 0, 0,
539-
R.isByteAddressBuffer()
540-
? D3D12_BUFFER_UAV_FLAG_RAW
541-
: D3D12_BUFFER_UAV_FLAG_NONE}}};
554+
{D3D12_BUFFER_UAV{
555+
0, NumElts, R.isStructuredBuffer() ? EltSize : 0, CounterOffset,
556+
R.isByteAddressBuffer() ? D3D12_BUFFER_UAV_FLAG_RAW
557+
: D3D12_BUFFER_UAV_FLAG_NONE}}};
542558

543559
llvm::outs() << "UAV: HeapIdx = " << HeapIdx << " EltSize = " << EltSize
544-
<< " NumElts = " << NumElts << "\n";
560+
<< " NumElts = " << NumElts << " HasCounter = " << R.HasCounter
561+
<< "\n";
545562
D3D12_CPU_DESCRIPTOR_HANDLE UAVHandle =
546563
IS.DescHeap->GetCPUDescriptorHandleForHeapStart();
547564
UAVHandle.ptr += HeapIdx * Device->GetDescriptorHandleIncrementSize(
548565
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
549-
Device->CreateUnorderedAccessView(Buffer.Get(), nullptr, &UAVDesc,
566+
Device->CreateUnorderedAccessView(Buffer.Get(), CounterBuffer, &UAVDesc,
550567
UAVHandle);
551568
}
552569

@@ -895,6 +912,11 @@ class DXDevice : public offloadtest::Device {
895912
"Failed to map result."))
896913
return Err;
897914
memcpy(R.first->BufferPtr->Data.get(), DataPtr, R.first->size());
915+
if (R.first->HasCounter)
916+
memcpy(&R.first->BufferPtr->Counter,
917+
static_cast<char *>(DataPtr) +
918+
getUAVBufferCounterOffset(*R.first),
919+
sizeof(uint32_t));
898920
R.second.Readback->Unmap(0, nullptr);
899921
return llvm::Error::success();
900922
};

lib/Support/Pipeline.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ void MappingTraits<offloadtest::Buffer>::mapping(IO &I,
9898
I.mapRequired("Format", B.Format);
9999
I.mapOptional("Channels", B.Channels, 1);
100100
I.mapOptional("Stride", B.Stride, 0);
101+
I.mapOptional("Counter", B.Counter, 0);
101102
if (!I.outputting() && B.Stride != 0 && B.Channels != 1)
102103
I.setError("Cannot set a structure stride and more than one channel.");
103104
switch (B.Format) {
@@ -147,6 +148,7 @@ void MappingTraits<offloadtest::Resource>::mapping(IO &I,
147148
offloadtest::Resource &R) {
148149
I.mapRequired("Name", R.Name);
149150
I.mapRequired("Kind", R.Kind);
151+
I.mapOptional("HasCounter", R.HasCounter, 0);
150152
I.mapRequired("DirectXBinding", R.DXBinding);
151153
I.mapOptional("VulkanBinding", R.VKBinding);
152154
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#--- source.hlsl
2+
RWStructuredBuffer<uint> Out : register(u0);
3+
4+
[numthreads(1,1,1)]
5+
void main(uint GI : SV_GroupIndex) {
6+
Out.DecrementCounter();
7+
Out.DecrementCounter();
8+
Out.DecrementCounter();
9+
Out[GI] = Out.DecrementCounter();
10+
}
11+
12+
//--- pipeline.yaml
13+
---
14+
Shaders:
15+
- Stage: Compute
16+
Entry: main
17+
DispatchSize: [1, 1, 1]
18+
Buffers:
19+
- Name: Out
20+
Format: Hex32
21+
Stride: 4
22+
ZeroInitSize: 4
23+
DescriptorSets:
24+
- Resources:
25+
- Name: Out
26+
Kind: RWStructuredBuffer
27+
HasCounter: true
28+
DirectXBinding:
29+
Register: 0
30+
Space: 0
31+
...
32+
#--- end
33+
34+
# UNSUPPORTED: Vulkan
35+
# UNSUPPORTED: Metal
36+
# UNSUPPORTED: Clang
37+
38+
# RUN: split-file %s %t
39+
# RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl
40+
# RUN: %offloader %t/pipeline.yaml %t.o | FileCheck %s
41+
42+
# CHECK: Creating UAV: { Size = 4100, Register = u0, Space = 0, HasCounter = 1 }
43+
# CHECK: UAV: HeapIdx = 0 EltSize = 4 NumElts = 1 HasCounter = 1
44+
45+
# CHECK: Name: Out
46+
# CHECK: Counter: 4294967292
47+
# CHECK: Data: [
48+
# CHECK: 0xFFFFFFFC
49+
# CHECK: ]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#--- source.hlsl
2+
RWStructuredBuffer<int> Out : register(u0);
3+
4+
[numthreads(1,1,1)]
5+
void main(uint GI : SV_GroupIndex) {
6+
Out.IncrementCounter();
7+
Out.IncrementCounter();
8+
Out.IncrementCounter();
9+
Out[GI] = Out.IncrementCounter();
10+
}
11+
12+
//--- pipeline.yaml
13+
---
14+
Shaders:
15+
- Stage: Compute
16+
Entry: main
17+
DispatchSize: [1, 1, 1]
18+
Buffers:
19+
- Name: Out
20+
Format: Hex32
21+
Stride: 4
22+
ZeroInitSize: 4
23+
DescriptorSets:
24+
- Resources:
25+
- Name: Out
26+
Kind: RWStructuredBuffer
27+
HasCounter: true
28+
DirectXBinding:
29+
Register: 0
30+
Space: 0
31+
...
32+
#--- end
33+
34+
# UNSUPPORTED: Vulkan
35+
# UNSUPPORTED: Metal
36+
# UNSUPPORTED: Clang
37+
38+
# RUN: split-file %s %t
39+
# RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl
40+
# RUN: %offloader %t/pipeline.yaml %t.o | FileCheck %s
41+
42+
# CHECK: Creating UAV: { Size = 4100, Register = u0, Space = 0, HasCounter = 1 }
43+
# CHECK: UAV: HeapIdx = 0 EltSize = 4 NumElts = 1 HasCounter = 1
44+
45+
# CHECK: Name: Out
46+
# CHECK: Counter: 4
47+
# CHECK: Data: [
48+
# CHECK: 0x3
49+
# CHECK: ]

0 commit comments

Comments
 (0)