Skip to content

Commit 4dbbb1c

Browse files
committed
Separate preprocess and Eval
1 parent 598fce1 commit 4dbbb1c

File tree

5 files changed

+14
-30
lines changed

5 files changed

+14
-30
lines changed

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/DXResourceBinding.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ D3D12Quad sample(800, 600, L"D3D12 Quad");
3838

3939
namespace winrt::WinMLSamplesGalleryNative::implementation
4040
{
41-
winrt::com_array<float> DXResourceBinding::LaunchWindow() {
41+
void DXResourceBinding::LaunchWindow() {
4242

4343
const wchar_t* preprocessingModelFilePath = L"C:/Users/numform/Windows-Machine-Learning/Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/dx_preprocessor_efficient_net.onnx";
4444
preprocesingSession = CreateSession(preprocessingModelFilePath);
@@ -53,20 +53,15 @@ namespace winrt::WinMLSamplesGalleryNative::implementation
5353

5454
// Run the quad in a separate, detached thread
5555
d3d_th.detach();
56-
57-
winrt::com_array<float> eval_results(1000);
58-
for (int i = 0; i < 1000; i++) {
59-
eval_results[i] = 100;
60-
}
61-
return eval_results;
6256
}
6357

6458
winrt::com_array<float> DXResourceBinding::EvalORT() {
6559
D3D12Quad::D3DInfo info = sample.GetD3DInfo();
66-
bool running = true;
67-
return Preprocess(*preprocesingSession, *inferenceSession, info.device.Get(),
68-
running, info.swapChain.Get(), info.frameIndex, info.commandAllocator.Get(), info.commandList.Get(),
60+
Ort::Value preprocessedInput = Preprocess(*preprocesingSession, info.device.Get(),
61+
info.swapChain.Get(), info.frameIndex, info.commandAllocator.Get(), info.commandList.Get(),
6962
info.commandQueue.Get());
63+
winrt::com_array<float> results = Eval(*inferenceSession, preprocessedInput);
64+
return results;
7065
}
7166

7267
void DXResourceBinding::CloseWindow() {

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/DXResourceBinding.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace winrt::WinMLSamplesGalleryNative::implementation
66
struct DXResourceBinding : DXResourceBindingT<DXResourceBinding>
77
{
88
DXResourceBinding() = default;
9-
static winrt::com_array<float> LaunchWindow();
9+
static void LaunchWindow();
1010
static winrt::com_array<float> EvalORT();
1111
static void CloseWindow();
1212
};

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/ORTHelpers.cpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ Ort::Value CreateTensorValueUsingD3DResource(
236236
);
237237
}
238238

239-
std::vector<float> Eval(Ort::Session& session, const Ort::Value& prev_input) {
239+
winrt::com_array<float> Eval(Ort::Session& session, const Ort::Value& prev_input) {
240240
const char* modelInputTensorName = "images:0";
241241
const char* modelOutputTensorName = "Softmax:0";
242242
const std::array<int64_t, 4> inputShape = { 1, 224, 224, 3 };
@@ -273,9 +273,9 @@ std::vector<float> Eval(Ort::Session& session, const Ort::Value& prev_input) {
273273
auto output_tensors = session.Run(Ort::RunOptions{ nullptr }, input_node_names.data(), &prev_input, 1, output_node_names.data(), 1);
274274
// Get pointer to output tensor float values
275275
float* floatarr = output_tensors.front().GetTensorMutableData<float>();
276-
std::vector<float> final_results;
276+
winrt::com_array<float> final_results(1000);
277277
for (int i = 0; i < 1000; i++) {
278-
final_results.push_back(floatarr[i]);
278+
final_results[i] = floatarr[i];
279279
}
280280

281281
return final_results;
@@ -290,13 +290,10 @@ std::vector<float> Eval(Ort::Session& session, const Ort::Value& prev_input) {
290290
printf("Error running model inference: %s\n", exception.what());
291291
//return EXIT_FAILURE;
292292
}
293-
294293
}
295294

296-
winrt::com_array<float> Preprocess(Ort::Session& session,
297-
Ort::Session& inferenceSession,
295+
Ort::Value Preprocess(Ort::Session& session,
298296
ID3D12Device* device,
299-
bool& Running,
300297
IDXGISwapChain3* swapChain,
301298
UINT frameIndex,
302299
ID3D12CommandAllocator* commandAllocator,
@@ -346,7 +343,6 @@ winrt::com_array<float> Preprocess(Ort::Session& session,
346343
IID_PPV_ARGS(&new_buffer));
347344
if (FAILED(hr))
348345
{
349-
Running = false;
350346
//return false;
351347
}
352348

@@ -398,12 +394,7 @@ winrt::com_array<float> Preprocess(Ort::Session& session,
398394
session.Run(Ort::RunOptions{ nullptr }, input_node_names.data(),
399395
&inputTensor, 1, output_node_names.data(), &outputTensor, 1);
400396

401-
auto eval_results_std = Eval(inferenceSession, outputTensor);
402-
winrt::com_array<float> eval_results(1000);
403-
for (int i = 0; i < 1000; i++) {
404-
eval_results[i] = eval_results_std[i];
405-
}
406-
return eval_results;
397+
return outputTensor;
407398
}
408399
catch (Ort::Exception const& exception)
409400
{

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/ORTHelpers.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,10 @@ using namespace Microsoft::WRL;
3636

3737
Ort::Session CreateSession(const wchar_t* model_file_path);
3838

39-
std::vector<float> Eval(Ort::Session& session, const Ort::Value& prev_input);
39+
winrt::com_array<float> Eval(Ort::Session& session, const Ort::Value& prev_input);
4040

41-
winrt::com_array<float> Preprocess(Ort::Session& session,
42-
Ort::Session& inferenceSession,
41+
Ort::Value Preprocess(Ort::Session& session,
4342
ID3D12Device* device,
44-
bool& Running,
4543
IDXGISwapChain3* swapChain,
4644
UINT frameIndex,
4745
ID3D12CommandAllocator* commandAllocator,

Samples/WinMLSamplesGallery/WinMLSamplesGalleryNative/WinMLSamplesGalleryNative.idl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace WinMLSamplesGalleryNative
2727
[default_interface]
2828
runtimeclass DXResourceBinding
2929
{
30-
static float[] LaunchWindow();
30+
static void LaunchWindow();
3131
static float[] EvalORT();
3232
static void CloseWindow();
3333
}

0 commit comments

Comments
 (0)