Skip to content

Commit 3a44c4d

Browse files
committed
Fix preprocessing
1 parent 5ac01e0 commit 3a44c4d

File tree

5 files changed

+18
-18
lines changed

5 files changed

+18
-18
lines changed

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/D3D12Quad.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,13 @@ void D3D12Quad::LoadAssets()
157157
m_commandList->ResourceBarrier(1, &present_to_copy_src);
158158

159159
auto desc = m_renderTargets[m_frameIndex]->GetDesc();
160-
UINT64 rowSize, totalSize;
161-
m_device->GetCopyableFootprints(&desc, 0, 1, 0, nullptr, nullptr, &rowSize, &totalSize);
160+
UINT64 rowSizeInBytes, totalSizeInBytes;
161+
m_device->GetCopyableFootprints(&desc, 0, 1, 0, nullptr, nullptr, &rowSizeInBytes, &totalSizeInBytes);
162162
D3D12_PLACED_SUBRESOURCE_FOOTPRINT bufferFootprint = {};
163163
bufferFootprint.Footprint.Width = static_cast<UINT>(desc.Width);
164164
bufferFootprint.Footprint.Height = desc.Height;
165165
bufferFootprint.Footprint.Depth = 1;
166-
bufferFootprint.Footprint.RowPitch = static_cast<UINT>((rowSize + 255) & ~255);
166+
bufferFootprint.Footprint.RowPitch = static_cast<UINT>((rowSizeInBytes + 255) & ~255);
167167
bufferFootprint.Footprint.Format = desc.Format;
168168

169169
const CD3DX12_TEXTURE_COPY_LOCATION copyDest(currentBuffer.Get(), bufferFootprint);
@@ -395,19 +395,19 @@ void D3D12Quad::CreateCurrentBuffer()
395395
D3D12_RESOURCE_DESC bufferDesc = {};
396396
bufferDesc.DepthOrArraySize = 1;
397397
bufferDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
398-
bufferDesc.Flags = D3D12_RESOURCE_FLAG_NONE;
398+
bufferDesc.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;
399399
bufferDesc.Format = DXGI_FORMAT_UNKNOWN;
400400
//bufferDesc.Format = DXGI_FORMAT_R8G8B8A8_UNORM;
401401
bufferDesc.Height = 1;
402-
bufferDesc.Width = ((desc.Width + 255) & ~255) * 4 * desc.Height;
402+
bufferDesc.Width = ((desc.Width* 4 + 255) & ~255) * desc.Height;
403403
//bufferDesc.Width = static_cast<uint64_t>(800 * 600 * 3 * 4);
404404
bufferDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
405405
//bufferDesc.Layout = D3D12_TEXTURE_LAYOUT_UNKNOWN;
406406
bufferDesc.MipLevels = 1;
407407
bufferDesc.SampleDesc.Count = 1;
408408

409409
//auto heap_properties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT);
410-
const CD3DX12_HEAP_PROPERTIES readBackHeapProperties(D3D12_HEAP_TYPE_READBACK);
410+
const CD3DX12_HEAP_PROPERTIES readBackHeapProperties(D3D12_HEAP_TYPE_DEFAULT);
411411
auto some_hr = m_device->CreateCommittedResource(
412412
&readBackHeapProperties,
413413
D3D12_HEAP_FLAG_NONE,
@@ -535,13 +535,13 @@ void D3D12Quad::PopulateCommandList()
535535
//m_commandList->CopyResource(currentBuffer.Get(), m_renderTargets[m_frameIndex].Get());
536536

537537
auto desc = m_renderTargets[m_frameIndex]->GetDesc();
538-
UINT64 rowSize, totalSize;
539-
m_device->GetCopyableFootprints(&desc, 0, 1, 0, nullptr, nullptr, &rowSize, &totalSize);
538+
UINT64 rowSizeInBytes, totalSizeInBytes;
539+
m_device->GetCopyableFootprints(&desc, 0, 1, 0, nullptr, nullptr, &rowSizeInBytes, &totalSizeInBytes);
540540
D3D12_PLACED_SUBRESOURCE_FOOTPRINT bufferFootprint = {};
541541
bufferFootprint.Footprint.Width = desc.Width;
542542
bufferFootprint.Footprint.Height = desc.Height;
543543
bufferFootprint.Footprint.Depth = 1;
544-
bufferFootprint.Footprint.RowPitch = static_cast<UINT>((rowSize + 255) & ~255);
544+
bufferFootprint.Footprint.RowPitch = static_cast<UINT>((rowSizeInBytes + 255) & ~255);
545545
bufferFootprint.Footprint.Format = desc.Format;
546546

547547
const CD3DX12_TEXTURE_COPY_LOCATION copyDest(currentBuffer.Get(), bufferFootprint);

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/DXResourceBinding.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using TensorKind = winrt::Microsoft::AI::MachineLearning::TensorKind;
2323
using LearningModelBuilder = winrt::Microsoft::AI::MachineLearning::Experimental::LearningModelBuilder;
2424
using LearningModelOperator = winrt::Microsoft::AI::MachineLearning::Experimental::LearningModelOperator;
2525

26-
std::array<int64_t, 4> preprocessInputShape;
26+
std::array<int64_t, 2> preprocessInputShape;
2727

2828
std::array<long, 6> CalculateCenterFillDimensions(long oldH, long oldW, long h, long w)
2929
{
@@ -81,12 +81,12 @@ namespace winrt::WinMLSamplesGalleryNative::implementation
8181
long left = center_fill_dimensions[4];
8282
long right = center_fill_dimensions[5];
8383
winrt::hstring interpolationMode = L"nearest";
84-
long c = 3;
84+
long c = 4;
8585

8686
auto width = 800;
8787
auto height = 600;
88-
auto rowPitchInPixels = (width + 255) & ~255;
89-
auto rowPitchInBytes = rowPitchInPixels * 4;
88+
auto rowPitchInBytes = (width * 4 + 255) & ~255;
89+
auto rowPitchInPixels = rowPitchInBytes / 4;
9090
auto bufferInBytes = rowPitchInBytes * height;
9191
preprocessInputShape = { 1, bufferInBytes};
9292
//const std::array<int64_t, 4> preprocessInputShape = { 1, 512, 512, 4 };
@@ -158,7 +158,7 @@ namespace winrt::WinMLSamplesGalleryNative::implementation
158158

159159
auto reshape_op = LearningModelOperator(L"Reshape")
160160
.SetInput(L"data", L"CastOutput")
161-
.SetConstant(L"shape", TensorInt64Bit::CreateFromIterable({ 4 }, { 1, 512, 512, 4 }))
161+
.SetConstant(L"shape", TensorInt64Bit::CreateFromIterable({ 4 }, { 1, height, 832, 4 }))
162162
.SetOutput(L"reshaped", L"ReshapeOutput");
163163

164164
auto slice_1 = LearningModelOperator(L"Slice")
@@ -182,7 +182,7 @@ namespace winrt::WinMLSamplesGalleryNative::implementation
182182

183183
auto preprocessingModelBuilder =
184184
LearningModelBuilder::Create(12)
185-
.Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::Float, preprocessInputShape))
185+
.Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::UInt8, preprocessInputShape))
186186
.Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, preprocessOutputShape))
187187
.Operators().Add(cast_op)
188188
.Operators().Add(reshape_op)

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/ORTHelpers.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Ort::Session CreateSession(const wchar_t* model_file_path)
4848
// image from 512 x 512 x 4 to 224 x 224 x 3
4949
Ort::Value Preprocess(Ort::Session& session,
5050
ComPtr<ID3D12Resource> currentBuffer,
51-
const std::array<int64_t, 4> inputShape)
51+
const std::array<int64_t, 2> inputShape)
5252
{
5353
// Init OrtAPI
5454
OrtApi const& ortApi = Ort::GetApi(); // Uses ORT_API_VERSION
@@ -65,7 +65,7 @@ Ort::Value Preprocess(Ort::Session& session,
6565
memoryInformation,
6666
currentBuffer.Get(),
6767
inputShape,
68-
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
68+
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
6969
/*out*/ IID_PPV_ARGS_Helper(inputTensorEpWrapper.GetAddressOf())
7070
);
7171

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/ORTHelpers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Ort::Session CreateSession(const wchar_t* model_file_path);
1414

1515
Ort::Value Preprocess(Ort::Session& session,
1616
ComPtr<ID3D12Resource> currentBuffer,
17-
const std::array<int64_t, 4> inputShape);
17+
const std::array<int64_t, 2> inputShape);
1818

1919
winrt::com_array<float> Eval(Ort::Session& session, const Ort::Value& prev_input);
2020

Binary file not shown.

0 commit comments

Comments
 (0)