Skip to content

Commit 2fa6306

Browse files
committed
Clean up code
1 parent 65ff453 commit 2fa6306

File tree

3 files changed

+13
-87
lines changed

3 files changed

+13
-87
lines changed

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/DXResourceBinding.cpp

Lines changed: 2 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
#include "DXResourceBinding.g.cpp"
77
#include "stdafx.h"
88
#include "ORTHelpers.h"
9-
#include "winrt/Microsoft.AI.MachineLearning.h"
10-
#include "winrt/Microsoft.AI.MachineLearning.Experimental.h"
119

1210
using namespace winrt::Microsoft::AI::MachineLearning;
1311

@@ -18,88 +16,10 @@ static std::optional<Ort::Session> preprocesingSession;
1816
static std::optional<Ort::Session> inferenceSession;
1917
D3D12Quad sample(800, 600, L"D3D12 Quad");
2018

21-
using TensorInt64Bit = winrt::Microsoft::AI::MachineLearning::TensorInt64Bit;
22-
using TensorKind = winrt::Microsoft::AI::MachineLearning::TensorKind;
23-
using LearningModelBuilder = winrt::Microsoft::AI::MachineLearning::Experimental::LearningModelBuilder;
24-
using LearningModelOperator = winrt::Microsoft::AI::MachineLearning::Experimental::LearningModelOperator;
25-
26-
std::array<int64_t, 2> preprocessInputShape;
27-
28-
std::array<long, 6> CalculateCenterFillDimensions(long oldH, long oldW, long h, long w)
29-
{
30-
long resizedW, resizedH, top, bottom, left, right;
31-
auto oldHFloat = (float)oldH;
32-
auto oldWFloat = (float)oldW;
33-
auto hFloat = (float)h;
34-
auto wFloat = (float)w;
35-
36-
auto oldAspectRatio = oldWFloat / oldHFloat;
37-
auto newAspectRatio = wFloat / hFloat;
38-
39-
auto scale = (newAspectRatio < oldAspectRatio) ? (hFloat / oldHFloat) : (wFloat / oldWFloat);
40-
resizedW = (newAspectRatio < oldAspectRatio) ? (long)std::floor(scale * oldWFloat) : w;
41-
resizedH = (newAspectRatio < oldAspectRatio) ? h : (long)std::floor(scale * oldHFloat);
42-
long totalPad = (newAspectRatio < oldAspectRatio) ? resizedW - w : resizedH - h;
43-
long biggerDim = (newAspectRatio < oldAspectRatio) ? w : h;
44-
long first = (totalPad % 2 == 0) ? totalPad / 2 : (long)std::floor(totalPad / 2.0f);
45-
long second = first + biggerDim;
46-
47-
if (newAspectRatio < oldAspectRatio)
48-
{
49-
top = 0;
50-
bottom = h;
51-
left = first;
52-
right = second;
53-
}
54-
else
55-
{
56-
top = first;
57-
bottom = second;
58-
left = 0;
59-
right = w;
60-
}
61-
62-
std::array<long, 6> new_dimensions = { resizedW, resizedH, top, bottom, left, right };
63-
return new_dimensions;
64-
}
65-
6619
namespace winrt::WinMLSamplesGalleryNative::implementation
6720
{
6821
// Create ORT Sessions and launch D3D window in a separate thread
6922
void DXResourceBinding::LaunchWindow() {
70-
71-
72-
long newH = 224;
73-
long newW = 224;
74-
long h = 512;
75-
long w = 512;
76-
std::array<long, 6> center_fill_dimensions = CalculateCenterFillDimensions(h, w, newH, newW);
77-
long resizedW = center_fill_dimensions[0];
78-
long resizedH = center_fill_dimensions[1];
79-
long top = center_fill_dimensions[2];
80-
long bottom = center_fill_dimensions[3];
81-
long left = center_fill_dimensions[4];
82-
long right = center_fill_dimensions[5];
83-
winrt::hstring interpolationMode = L"nearest";
84-
long c = 4;
85-
86-
auto width = 800;
87-
auto height = 600;
88-
auto rowPitchInBytes = (width * 4 + 255) & ~255;
89-
auto rowPitchInPixels = rowPitchInBytes / 4;
90-
auto bufferInBytes = rowPitchInBytes * height;
91-
preprocessInputShape = { 1, bufferInBytes};
92-
//const std::array<int64_t, 4> preprocessInputShape = { 1, 512, 512, 4 };
93-
//const std::array<int64_t, 4> preprocessInputShape = { 1, -1, -1, -1 };
94-
95-
const std::array<int64_t, 4> preprocessOutputShape = { 1, 224, 224, 3 };
96-
97-
auto kernel = new float[] {
98-
0,0,1,
99-
0,1,0,
100-
1,0,0
101-
};
102-
10323
// Create ORT Sessions that will be used for preprocessing and classification
10424
preprocesingSession = CreateSession(Win32Application::GetAssetPath(L"dx_preprocessor_efficient_net_v2.onnx").c_str());
10525
inferenceSession = CreateSession(Win32Application::GetAssetPath(L"efficientnet-lite4-11.onnx").c_str());
@@ -120,9 +40,9 @@ namespace winrt::WinMLSamplesGalleryNative::implementation
12040
// Get the buffer currently being drawn to the screen
12141
ComPtr<ID3D12Resource> currentBuffer = sample.GetCurrentBuffer();
12242

123-
// Preprocess the buffer (shrink from 512 x 512 x 4 to 224 x 224 x 3)
43+
// Preprocess the buffer (shrink to 224 x 224 x 3)
12444
Ort::Value preprocessedInput = Preprocess(*preprocesingSession,
125-
currentBuffer, preprocessInputShape);
45+
currentBuffer);
12646

12747
// Classify the image using EfficientNet and return the results
12848
winrt::com_array<float> results = Eval(*inferenceSession, preprocessedInput);

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/ORTHelpers.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ Ort::Session CreateSession(const wchar_t* model_file_path)
4747
// Run the buffer through a preprocessing model that will shrink the
4848
// image from 512 x 512 x 4 to 224 x 224 x 3
4949
Ort::Value Preprocess(Ort::Session& session,
50-
ComPtr<ID3D12Resource> currentBuffer,
51-
const std::array<int64_t, 2> inputShape)
50+
ComPtr<ID3D12Resource> currentBuffer)
5251
{
5352
// Init OrtAPI
5453
OrtApi const& ortApi = Ort::GetApi(); // Uses ORT_API_VERSION
@@ -59,7 +58,15 @@ Ort::Value Preprocess(Ort::Session& session,
5958
const char* memoryInformationName = "DML";
6059
Ort::MemoryInfo memoryInformation(memoryInformationName, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
6160
ComPtr<IUnknown> inputTensorEpWrapper;
62-
//const std::array<int64_t, 4> inputShape = { 1, 512, 512, 4 };
61+
62+
// Calculate input shape
63+
auto width = 800;
64+
auto height = 600;
65+
auto rowPitchInBytes = (width * 4 + 255) & ~255;
66+
auto rowPitchInPixels = rowPitchInBytes / 4;
67+
auto bufferInBytes = rowPitchInBytes * height;
68+
const std::array<int64_t, 2> inputShape = { 1, bufferInBytes };
69+
6370
Ort::Value inputTensor = CreateTensorValueFromD3DResource(
6471
*ortDmlApi,
6572
memoryInformation,

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/ORTHelpers.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ using namespace Microsoft::WRL;
1313
Ort::Session CreateSession(const wchar_t* model_file_path);
1414

1515
Ort::Value Preprocess(Ort::Session& session,
16-
ComPtr<ID3D12Resource> currentBuffer,
17-
const std::array<int64_t, 2> inputShape);
16+
ComPtr<ID3D12Resource> currentBuffer);
1817

1918
winrt::com_array<float> Eval(Ort::Session& session, const Ort::Value& prev_input);
2019

0 commit comments

Comments
 (0)