Skip to content

Commit a5187ec

Browse files
authored
Initial support for Texture2D types (#257)
This adds initial support for Texture2D and RWTexture2D across DirectX, Vulkan and Metal. Fixes #208
1 parent dc51835 commit a5187ec

File tree

7 files changed

+657
-184
lines changed

7 files changed

+657
-184
lines changed

include/Support/Pipeline.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,11 @@ enum class ResourceKind {
5050
Buffer,
5151
StructuredBuffer,
5252
ByteAddressBuffer,
53+
Texture2D,
5354
RWBuffer,
5455
RWStructuredBuffer,
5556
RWByteAddressBuffer,
57+
RWTexture2D,
5658
ConstantBuffer,
5759
};
5860

@@ -138,6 +140,8 @@ struct Resource {
138140
switch (Kind) {
139141
case ResourceKind::Buffer:
140142
case ResourceKind::RWBuffer:
143+
case ResourceKind::Texture2D:
144+
case ResourceKind::RWTexture2D:
141145
return false;
142146
case ResourceKind::StructuredBuffer:
143147
case ResourceKind::RWStructuredBuffer:
@@ -149,6 +153,23 @@ struct Resource {
149153
llvm_unreachable("All cases handled");
150154
}
151155

156+
bool isTexture() const {
157+
switch (Kind) {
158+
case ResourceKind::Buffer:
159+
case ResourceKind::RWBuffer:
160+
case ResourceKind::StructuredBuffer:
161+
case ResourceKind::RWStructuredBuffer:
162+
case ResourceKind::ByteAddressBuffer:
163+
case ResourceKind::RWByteAddressBuffer:
164+
case ResourceKind::ConstantBuffer:
165+
return false;
166+
case ResourceKind::Texture2D:
167+
case ResourceKind::RWTexture2D:
168+
return true;
169+
}
170+
llvm_unreachable("All cases handled");
171+
}
172+
152173
bool isByteAddressBuffer() const {
153174
switch (Kind) {
154175
case ResourceKind::ByteAddressBuffer:
@@ -178,15 +199,19 @@ struct Resource {
178199
case ResourceKind::Buffer:
179200
case ResourceKind::StructuredBuffer:
180201
case ResourceKind::ByteAddressBuffer:
202+
case ResourceKind::Texture2D:
181203
case ResourceKind::ConstantBuffer:
182204
return false;
183205
case ResourceKind::RWBuffer:
184206
case ResourceKind::RWStructuredBuffer:
185207
case ResourceKind::RWByteAddressBuffer:
208+
case ResourceKind::RWTexture2D:
186209
return true;
187210
}
188211
llvm_unreachable("All cases handled");
189212
}
213+
214+
bool isReadOnly() const { return !isReadWrite(); }
190215
};
191216

192217
struct DescriptorSet {
@@ -360,9 +385,11 @@ template <> struct ScalarEnumerationTraits<offloadtest::ResourceKind> {
360385
ENUM_CASE(Buffer);
361386
ENUM_CASE(StructuredBuffer);
362387
ENUM_CASE(ByteAddressBuffer);
388+
ENUM_CASE(Texture2D);
363389
ENUM_CASE(RWBuffer);
364390
ENUM_CASE(RWStructuredBuffer);
365391
ENUM_CASE(RWByteAddressBuffer);
392+
ENUM_CASE(RWTexture2D);
366393
ENUM_CASE(ConstantBuffer);
367394
#undef ENUM_CASE
368395
}

lib/API/DX/Device.cpp

Lines changed: 152 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ static DXGI_FORMAT getDXFormat(DataFormat Format, int Channels) {
7373
return DXGI_FORMAT_UNKNOWN;
7474
}
7575

76-
static DXGI_FORMAT getRawDXFormat(Resource &R) {
76+
static DXGI_FORMAT getRawDXFormat(const Resource &R) {
7777
if (!R.isByteAddressBuffer())
7878
return DXGI_FORMAT_UNKNOWN;
7979

@@ -89,33 +89,50 @@ static DXGI_FORMAT getRawDXFormat(Resource &R) {
8989
return DXGI_FORMAT_UNKNOWN;
9090
}
9191

92-
static uint32_t getUAVBufferSize(Resource &R) {
92+
static uint32_t getUAVBufferSize(const Resource &R) {
9393
return R.HasCounter
9494
? llvm::alignTo(R.size(), D3D12_UAV_COUNTER_PLACEMENT_ALIGNMENT) +
9595
sizeof(uint32_t)
9696
: R.size();
9797
}
9898

99-
static uint32_t getUAVBufferCounterOffset(Resource &R) {
99+
static uint32_t getUAVBufferCounterOffset(const Resource &R) {
100100
return R.HasCounter
101101
? llvm::alignTo(R.size(), D3D12_UAV_COUNTER_PLACEMENT_ALIGNMENT)
102102
: 0;
103103
}
104104

105-
namespace {
105+
static D3D12_RESOURCE_DIMENSION getDXDimension(ResourceKind RK) {
106+
switch (RK) {
107+
case ResourceKind::Buffer:
108+
case ResourceKind::StructuredBuffer:
109+
case ResourceKind::ByteAddressBuffer:
110+
case ResourceKind::RWStructuredBuffer:
111+
case ResourceKind::RWBuffer:
112+
case ResourceKind::RWByteAddressBuffer:
113+
case ResourceKind::ConstantBuffer:
114+
return D3D12_RESOURCE_DIMENSION_BUFFER;
115+
case ResourceKind::Texture2D:
116+
case ResourceKind::RWTexture2D:
117+
return D3D12_RESOURCE_DIMENSION_TEXTURE2D;
118+
}
119+
llvm_unreachable("All cases handled");
120+
}
106121

107122
enum DXResourceKind { UAV, SRV, CBV };
108123

109-
DXResourceKind getDXKind(offloadtest::ResourceKind RK) {
124+
static DXResourceKind getDXKind(offloadtest::ResourceKind RK) {
110125
switch (RK) {
111126
case ResourceKind::Buffer:
112127
case ResourceKind::StructuredBuffer:
113128
case ResourceKind::ByteAddressBuffer:
129+
case ResourceKind::Texture2D:
114130
return SRV;
115131

116132
case ResourceKind::RWStructuredBuffer:
117133
case ResourceKind::RWBuffer:
118134
case ResourceKind::RWByteAddressBuffer:
135+
case ResourceKind::RWTexture2D:
119136
return UAV;
120137

121138
case ResourceKind::ConstantBuffer:
@@ -124,6 +141,99 @@ DXResourceKind getDXKind(offloadtest::ResourceKind RK) {
124141
llvm_unreachable("All cases handled");
125142
}
126143

144+
static D3D12_RESOURCE_DESC getResourceDescription(const Resource &R) {
145+
const D3D12_RESOURCE_DIMENSION Dimension = getDXDimension(R.Kind);
146+
const offloadtest::Buffer &B = *R.BufferPtr;
147+
const DXGI_FORMAT Format =
148+
R.isTexture() ? getDXFormat(B.Format, B.Channels) : DXGI_FORMAT_UNKNOWN;
149+
const uint32_t Width =
150+
R.isTexture() ? B.OutputProps.Width : getUAVBufferSize(R);
151+
const uint32_t Height = R.isTexture() ? B.OutputProps.Height : 1;
152+
const D3D12_TEXTURE_LAYOUT Layout = R.isTexture()
153+
? D3D12_TEXTURE_LAYOUT_UNKNOWN
154+
: D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
155+
const D3D12_RESOURCE_FLAGS Flags =
156+
R.isReadWrite() ? D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS
157+
: D3D12_RESOURCE_FLAG_NONE;
158+
const D3D12_RESOURCE_DESC ResDesc = {Dimension, 0, Width, Height, 1, 1,
159+
Format, {1, 0}, Layout, Flags};
160+
return ResDesc;
161+
}
162+
163+
static D3D12_SHADER_RESOURCE_VIEW_DESC getSRVDescription(const Resource &R) {
164+
const uint32_t EltSize = R.getElementSize();
165+
const uint32_t NumElts = R.size() / EltSize;
166+
167+
llvm::outs() << " EltSize = " << EltSize << " NumElts = " << NumElts
168+
<< "\n";
169+
D3D12_SHADER_RESOURCE_VIEW_DESC Desc = {};
170+
Desc.Format = R.isRaw()
171+
? getRawDXFormat(R)
172+
: getDXFormat(R.BufferPtr->Format, R.BufferPtr->Channels);
173+
Desc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING;
174+
switch (R.Kind) {
175+
case ResourceKind::Buffer:
176+
case ResourceKind::StructuredBuffer:
177+
case ResourceKind::ByteAddressBuffer:
178+
179+
Desc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER;
180+
Desc.Buffer =
181+
D3D12_BUFFER_SRV{0, NumElts, R.isStructuredBuffer() ? EltSize : 0,
182+
R.isByteAddressBuffer() ? D3D12_BUFFER_SRV_FLAG_RAW
183+
: D3D12_BUFFER_SRV_FLAG_NONE};
184+
break;
185+
case ResourceKind::Texture2D:
186+
Desc.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE2D;
187+
Desc.Texture2D = D3D12_TEX2D_SRV{0, 1, 0, 0};
188+
break;
189+
case ResourceKind::RWStructuredBuffer:
190+
case ResourceKind::RWBuffer:
191+
case ResourceKind::RWByteAddressBuffer:
192+
case ResourceKind::RWTexture2D:
193+
case ResourceKind::ConstantBuffer:
194+
llvm_unreachable("Not an SRV type!");
195+
}
196+
return Desc;
197+
}
198+
199+
static D3D12_UNORDERED_ACCESS_VIEW_DESC getUAVDescription(const Resource &R) {
200+
const uint32_t EltSize = R.getElementSize();
201+
const uint32_t NumElts = R.size() / EltSize;
202+
const uint32_t CounterOffset = getUAVBufferCounterOffset(R);
203+
204+
llvm::outs() << " EltSize = " << EltSize << " NumElts = " << NumElts
205+
<< "\n";
206+
D3D12_UNORDERED_ACCESS_VIEW_DESC Desc = {};
207+
Desc.Format = R.isRaw()
208+
? getRawDXFormat(R)
209+
: getDXFormat(R.BufferPtr->Format, R.BufferPtr->Channels);
210+
switch (R.Kind) {
211+
case ResourceKind::RWBuffer:
212+
case ResourceKind::RWStructuredBuffer:
213+
case ResourceKind::RWByteAddressBuffer:
214+
215+
Desc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
216+
Desc.Buffer = D3D12_BUFFER_UAV{
217+
0, NumElts, R.isStructuredBuffer() ? EltSize : 0, CounterOffset,
218+
R.isByteAddressBuffer() ? D3D12_BUFFER_UAV_FLAG_RAW
219+
: D3D12_BUFFER_UAV_FLAG_NONE};
220+
break;
221+
case ResourceKind::RWTexture2D:
222+
Desc.ViewDimension = D3D12_UAV_DIMENSION_TEXTURE2D;
223+
Desc.Texture2D = D3D12_TEX2D_UAV{0, 0};
224+
break;
225+
case ResourceKind::StructuredBuffer:
226+
case ResourceKind::Buffer:
227+
case ResourceKind::ByteAddressBuffer:
228+
case ResourceKind::Texture2D:
229+
case ResourceKind::ConstantBuffer:
230+
llvm_unreachable("Not a UAV type!");
231+
}
232+
return Desc;
233+
}
234+
235+
namespace {
236+
127237
class DXDevice : public offloadtest::Device {
128238
private:
129239
ComPtr<IDXCoreAdapter> Adapter;
@@ -377,8 +487,20 @@ class DXDevice : public offloadtest::Device {
377487
ComPtr<ID3D12Resource> Destination,
378488
ComPtr<ID3D12Resource> Source) {
379489
addUploadBeginBarrier(IS, Destination);
380-
IS.CmdList->CopyBufferRegion(Destination.Get(), 0, Source.Get(), 0,
381-
R.size());
490+
if (R.isTexture()) {
491+
const offloadtest::Buffer &B = *R.BufferPtr;
492+
const D3D12_PLACED_SUBRESOURCE_FOOTPRINT Footprint{
493+
0, CD3DX12_SUBRESOURCE_FOOTPRINT(
494+
getDXFormat(B.Format, B.Channels), B.OutputProps.Width,
495+
B.OutputProps.Height, 1,
496+
B.OutputProps.Width * B.getElementSize())};
497+
const CD3DX12_TEXTURE_COPY_LOCATION DstLoc(Destination.Get(), 0);
498+
const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc(Source.Get(), Footprint);
499+
500+
IS.CmdList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
501+
} else
502+
IS.CmdList->CopyBufferRegion(Destination.Get(), 0, Source.Get(), 0,
503+
R.size());
382504
addUploadEndBarrier(IS, Destination, R.isReadWrite());
383505
}
384506

@@ -391,17 +513,7 @@ class DXDevice : public offloadtest::Device {
391513

392514
const D3D12_HEAP_PROPERTIES HeapProp =
393515
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT);
394-
const D3D12_RESOURCE_DESC ResDesc = {
395-
D3D12_RESOURCE_DIMENSION_BUFFER,
396-
0,
397-
R.size(),
398-
1,
399-
1,
400-
1,
401-
DXGI_FORMAT_UNKNOWN,
402-
{1, 0},
403-
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
404-
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS};
516+
const D3D12_RESOURCE_DESC ResDesc = getResourceDescription(R);
405517

406518
if (auto Err = HR::toError(Device->CreateCommittedResource(
407519
&HeapProp, D3D12_HEAP_FLAG_NONE, &ResDesc,
@@ -438,22 +550,9 @@ class DXDevice : public offloadtest::Device {
438550

439551
void bindSRV(Resource &R, InvocationState &IS, const uint32_t HeapIdx,
440552
ComPtr<ID3D12Resource> Buffer) {
441-
const uint32_t EltSize = R.getElementSize();
442-
const uint32_t NumElts = R.size() / EltSize;
443-
DXGI_FORMAT const EltFormat =
444-
R.isRaw() ? getRawDXFormat(R)
445-
: getDXFormat(R.BufferPtr->Format, R.BufferPtr->Channels);
446-
const D3D12_SHADER_RESOURCE_VIEW_DESC SRVDesc = {
447-
EltFormat,
448-
D3D12_SRV_DIMENSION_BUFFER,
449-
D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
450-
{D3D12_BUFFER_SRV{0, NumElts, R.isStructuredBuffer() ? EltSize : 0,
451-
R.isByteAddressBuffer()
452-
? D3D12_BUFFER_SRV_FLAG_RAW
453-
: D3D12_BUFFER_SRV_FLAG_NONE}}};
454-
455-
llvm::outs() << "SRV: HeapIdx = " << HeapIdx << " EltSize = " << EltSize
456-
<< " NumElts = " << NumElts << "\n";
553+
llvm::outs() << "SRV: HeapIdx = " << HeapIdx << "\n";
554+
const D3D12_SHADER_RESOURCE_VIEW_DESC SRVDesc = getSRVDescription(R);
555+
457556
D3D12_CPU_DESCRIPTOR_HANDLE SRVHandle =
458557
IS.DescHeap->GetCPUDescriptorHandleForHeapStart();
459558
SRVHandle.ptr += HeapIdx * Device->GetDescriptorHandleIncrementSize(
@@ -472,17 +571,7 @@ class DXDevice : public offloadtest::Device {
472571

473572
const D3D12_HEAP_PROPERTIES HeapProp =
474573
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT);
475-
const D3D12_RESOURCE_DESC ResDesc = {
476-
D3D12_RESOURCE_DIMENSION_BUFFER,
477-
0,
478-
BufferSize,
479-
1,
480-
1,
481-
1,
482-
DXGI_FORMAT_UNKNOWN,
483-
{1, 0},
484-
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
485-
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS};
574+
const D3D12_RESOURCE_DESC ResDesc = getResourceDescription(R);
486575

487576
if (auto Err = HR::toError(Device->CreateCommittedResource(
488577
&HeapProp, D3D12_HEAP_FLAG_NONE, &ResDesc,
@@ -541,24 +630,10 @@ class DXDevice : public offloadtest::Device {
541630

542631
void bindUAV(Resource &R, InvocationState &IS, const uint32_t HeapIdx,
543632
ComPtr<ID3D12Resource> Buffer) {
544-
const uint32_t EltSize = R.getElementSize();
545-
const uint32_t NumElts = R.size() / EltSize;
546633
ID3D12Resource *CounterBuffer = R.HasCounter ? Buffer.Get() : nullptr;
547-
const uint32_t CounterOffset = getUAVBufferCounterOffset(R);
548-
DXGI_FORMAT const EltFormat =
549-
R.isRaw() ? getRawDXFormat(R)
550-
: getDXFormat(R.BufferPtr->Format, R.BufferPtr->Channels);
551-
const D3D12_UNORDERED_ACCESS_VIEW_DESC UAVDesc = {
552-
EltFormat,
553-
D3D12_UAV_DIMENSION_BUFFER,
554-
{D3D12_BUFFER_UAV{
555-
0, NumElts, R.isStructuredBuffer() ? EltSize : 0, CounterOffset,
556-
R.isByteAddressBuffer() ? D3D12_BUFFER_UAV_FLAG_RAW
557-
: D3D12_BUFFER_UAV_FLAG_NONE}}};
558-
559-
llvm::outs() << "UAV: HeapIdx = " << HeapIdx << " EltSize = " << EltSize
560-
<< " NumElts = " << NumElts << " HasCounter = " << R.HasCounter
561-
<< "\n";
634+
llvm::outs() << "UAV: HeapIdx = " << HeapIdx << "\n";
635+
const D3D12_UNORDERED_ACCESS_VIEW_DESC UAVDesc = getUAVDescription(R);
636+
562637
D3D12_CPU_DESCRIPTOR_HANDLE UAVHandle =
563638
IS.DescHeap->GetCPUDescriptorHandleForHeapStart();
564639
UAVHandle.ptr += HeapIdx * Device->GetDescriptorHandleIncrementSize(
@@ -891,7 +966,21 @@ class DXDevice : public offloadtest::Device {
891966
if (R.second.Readback == nullptr)
892967
return;
893968
addReadbackBeginBarrier(IS, R.second.Buffer);
894-
IS.CmdList->CopyResource(R.second.Readback.Get(), R.second.Buffer.Get());
969+
if (R.first->isTexture()) {
970+
const offloadtest::Buffer &B = *R.first->BufferPtr;
971+
const D3D12_PLACED_SUBRESOURCE_FOOTPRINT Footprint{
972+
0, CD3DX12_SUBRESOURCE_FOOTPRINT(
973+
getDXFormat(B.Format, B.Channels), B.OutputProps.Width,
974+
B.OutputProps.Height, 1,
975+
B.OutputProps.Width * B.getElementSize())};
976+
const CD3DX12_TEXTURE_COPY_LOCATION DstLoc(R.second.Readback.Get(),
977+
Footprint);
978+
const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc(R.second.Buffer.Get(), 0);
979+
980+
IS.CmdList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
981+
} else
982+
IS.CmdList->CopyResource(R.second.Readback.Get(),
983+
R.second.Buffer.Get());
895984
addReadbackEndBarrier(IS, R.second.Buffer);
896985
};
897986

0 commit comments

Comments
 (0)