@@ -89,6 +89,19 @@ static DXGI_FORMAT getRawDXFormat(Resource &R) {
8989 return DXGI_FORMAT_UNKNOWN;
9090}
9191
92+ static uint32_t getUAVBufferSize (Resource &R) {
93+ return R.HasCounter
94+ ? llvm::alignTo (R.size (), D3D12_UAV_COUNTER_PLACEMENT_ALIGNMENT) +
95+ sizeof (uint32_t )
96+ : R.size ();
97+ }
98+
99+ static uint32_t getUAVBufferCounterOffset (Resource &R) {
100+ return R.HasCounter
101+ ? llvm::alignTo (R.size (), D3D12_UAV_COUNTER_PLACEMENT_ALIGNMENT)
102+ : 0 ;
103+ }
104+
92105namespace {
93106
94107enum DXResourceKind { UAV, SRV, CBV };
@@ -449,9 +462,10 @@ class DXDevice : public offloadtest::Device {
449462 }
450463
451464 llvm::Expected<ResourceSet> createUAV (Resource &R, InvocationState &IS) {
452- llvm::outs () << " Creating UAV: { Size = " << R.size () << " , Register = u"
465+ const uint32_t BufferSize = getUAVBufferSize (R);
466+ llvm::outs () << " Creating UAV: { Size = " << BufferSize << " , Register = u"
453467 << R.DXBinding .Register << " , Space = " << R.DXBinding .Space
454- << " }\n " ;
468+ << " , HasCounter = " << R. HasCounter << " }\n " ;
455469 ComPtr<ID3D12Resource> Buffer;
456470 ComPtr<ID3D12Resource> UploadBuffer;
457471 ComPtr<ID3D12Resource> ReadBackBuffer;
@@ -461,7 +475,7 @@ class DXDevice : public offloadtest::Device {
461475 const D3D12_RESOURCE_DESC ResDesc = {
462476 D3D12_RESOURCE_DIMENSION_BUFFER,
463477 0 ,
464- R. size () ,
478+ BufferSize ,
465479 1 ,
466480 1 ,
467481 1 ,
@@ -480,7 +494,7 @@ class DXDevice : public offloadtest::Device {
480494 const D3D12_HEAP_PROPERTIES UploadHeapProp =
481495 CD3DX12_HEAP_PROPERTIES (D3D12_HEAP_TYPE_UPLOAD);
482496 const D3D12_RESOURCE_DESC UploadResDesc =
483- CD3DX12_RESOURCE_DESC::Buffer (R. size () );
497+ CD3DX12_RESOURCE_DESC::Buffer (BufferSize );
484498
485499 if (auto Err =
486500 HR::toError (Device->CreateCommittedResource (
@@ -495,7 +509,7 @@ class DXDevice : public offloadtest::Device {
495509 const D3D12_RESOURCE_DESC ReadBackResDesc = {
496510 D3D12_RESOURCE_DIMENSION_BUFFER,
497511 0 ,
498- R. size () ,
512+ BufferSize ,
499513 1 ,
500514 1 ,
501515 1 ,
@@ -529,24 +543,27 @@ class DXDevice : public offloadtest::Device {
529543 ComPtr<ID3D12Resource> Buffer) {
530544 const uint32_t EltSize = R.getElementSize ();
531545 const uint32_t NumElts = R.size () / EltSize;
546+ ID3D12Resource *CounterBuffer = R.HasCounter ? Buffer.Get () : nullptr ;
547+ const uint32_t CounterOffset = getUAVBufferCounterOffset (R);
532548 DXGI_FORMAT const EltFormat =
533549 R.isRaw () ? getRawDXFormat (R)
534550 : getDXFormat (R.BufferPtr ->Format , R.BufferPtr ->Channels );
535551 const D3D12_UNORDERED_ACCESS_VIEW_DESC UAVDesc = {
536552 EltFormat,
537553 D3D12_UAV_DIMENSION_BUFFER,
538- {D3D12_BUFFER_UAV{0 , NumElts, R. isStructuredBuffer () ? EltSize : 0 , 0 ,
539- R. isByteAddressBuffer ()
540- ? D3D12_BUFFER_UAV_FLAG_RAW
541- : D3D12_BUFFER_UAV_FLAG_NONE}}};
554+ {D3D12_BUFFER_UAV{
555+ 0 , NumElts, R. isStructuredBuffer () ? EltSize : 0 , CounterOffset,
556+ R. isByteAddressBuffer () ? D3D12_BUFFER_UAV_FLAG_RAW
557+ : D3D12_BUFFER_UAV_FLAG_NONE}}};
542558
543559 llvm::outs () << " UAV: HeapIdx = " << HeapIdx << " EltSize = " << EltSize
544- << " NumElts = " << NumElts << " \n " ;
560+ << " NumElts = " << NumElts << " HasCounter = " << R.HasCounter
561+ << " \n " ;
545562 D3D12_CPU_DESCRIPTOR_HANDLE UAVHandle =
546563 IS.DescHeap ->GetCPUDescriptorHandleForHeapStart ();
547564 UAVHandle.ptr += HeapIdx * Device->GetDescriptorHandleIncrementSize (
548565 D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
549- Device->CreateUnorderedAccessView (Buffer.Get (), nullptr , &UAVDesc,
566+ Device->CreateUnorderedAccessView (Buffer.Get (), CounterBuffer , &UAVDesc,
550567 UAVHandle);
551568 }
552569
@@ -895,6 +912,11 @@ class DXDevice : public offloadtest::Device {
895912 " Failed to map result." ))
896913 return Err;
897914 memcpy (R.first ->BufferPtr ->Data .get (), DataPtr, R.first ->size ());
915+ if (R.first ->HasCounter )
916+ memcpy (&R.first ->BufferPtr ->Counter ,
917+ static_cast <char *>(DataPtr) +
918+ getUAVBufferCounterOffset (*R.first ),
919+ sizeof (uint32_t ));
898920 R.second .Readback ->Unmap (0 , nullptr );
899921 return llvm::Error::success ();
900922 };
0 commit comments