Skip to content

Commit 7e81ec5

Browse files
refactor code
1 parent 6889310 commit 7e81ec5

File tree

3 files changed

+116
-90
lines changed

3 files changed

+116
-90
lines changed

Samples/BatchSupport/BatchSupport/SampleHelper.cpp

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#include "SampleHelper.h"
21
#include "pch.h"
2+
#include "SampleHelper.h"
33

44
#include "Windows.AI.MachineLearning.Native.h"
55
#include <MemoryBuffer.h>
@@ -17,6 +17,8 @@ using namespace winrt::Windows::Storage;
1717

1818
EXTERN_C IMAGE_DOS_HEADER __ImageBase;
1919

20+
#define BATCH_SIZE 3
21+
2022
namespace SampleHelper {
2123
std::wstring GetModulePath() {
2224
std::wstring val;
@@ -103,9 +105,7 @@ SoftwareBitmapToSoftwareTensor(SoftwareBitmap softwareBitmap) {
103105
}
104106

105107
VideoFrame LoadImageFile(hstring filePath) {
106-
DWORD ticks = GetTickCount();
107108
VideoFrame inputImage = nullptr;
108-
109109
try {
110110
// open the file
111111
StorageFile file = StorageFile::GetFileFromPathAsync(filePath).get();
@@ -122,13 +122,52 @@ VideoFrame LoadImageFile(hstring filePath) {
122122
"qualified paths\r\n");
123123
exit(EXIT_FAILURE);
124124
}
125-
126-
ticks = GetTickCount() - ticks;
127-
printf("image file loaded in %d ticks\n", ticks);
128125
// all done
129126
return inputImage;
130127
}
131128

129+
// Resolve the absolute on-disk path of the ONNX model to load, based on the
// requested batch-dimension mode.
//   "fixedBatchSize" -> <module dir>/SqueezeNet_free.onnx
//   anything else    -> <module dir>/SqueezeNet.onnx
// NOTE(review): pairing "fixedBatchSize" with the "_free" file looks swapped —
// confirm against the shipped model files before changing.
hstring GetModelPath(std::string modelType) {
  const auto moduleDir = static_cast<hstring>(GetModulePath().c_str());
  const wchar_t* fileName = (modelType == "fixedBatchSize")
                                ? L"SqueezeNet_free.onnx"
                                : L"SqueezeNet.onnx";
  return moduleDir + fileName;
}
140+
141+
// Build one batched input tensor of shape [BATCH_SIZE, 3, 224, 224] by
// loading three sample images and concatenating their CHW float data
// back-to-back in batch order.
TensorFloat CreateInputTensorFloat() {
  const std::vector<hstring> imageNames = {L"fish.png", L"kitten_224.png",
                                           L"fish.png"};
  std::vector<float> batchData;
  for (const hstring& imageName : imageNames) {
    hstring imagePath =
        static_cast<hstring>(GetModulePath().c_str()) + imageName;
    VideoFrame frame = LoadImageFile(imagePath);
    // Convert the decoded bitmap into a flat CHW float vector and append it.
    std::vector<float> tensorData =
        SoftwareBitmapToSoftwareTensor(frame.SoftwareBitmap());
    batchData.insert(batchData.end(), tensorData.begin(), tensorData.end());
  }
  std::vector<int64_t> shape{BATCH_SIZE, 3, 224, 224};
  return TensorFloat::CreateFromIterable(
      shape, single_threaded_vector<float>(std::move(batchData)).GetView());
}
158+
159+
IVector<VideoFrame> CreateVideoFrames() {
160+
std::vector<hstring> imageNames = { L"fish.png", L"kitten_224.png", L"fish.png" };
161+
std::vector<VideoFrame> inputFrames = {};
162+
for (hstring imageName : imageNames) {
163+
auto imagePath = static_cast<hstring>(GetModulePath().c_str()) + imageName;
164+
auto imageFrame = LoadImageFile(imagePath);
165+
inputFrames.emplace_back(imageFrame);
166+
}
167+
auto videoFrames = winrt::single_threaded_vector(std::move(inputFrames));
168+
return videoFrames;
169+
}
170+
132171
std::vector<std::string> LoadLabels(std::string labelsFilePath) {
133172
// Parse labels from labels file. We know the file's entries are already
134173
// sorted in order.
@@ -162,15 +201,15 @@ void PrintResults(IVectorView<float> results) {
162201
std::vector<std::string> labels = LoadLabels(labelsFilePath);
163202
// SqueezeNet returns a list of 1000 options, with probabilities for each,
164203
// loop through all
165-
for (uint32_t batchId = 0; batchId < 3; ++batchId) {
204+
for (uint32_t batchId = 0; batchId < BATCH_SIZE; ++batchId) {
166205
// Find the top probability
167206
float topProbability = 0;
168207
int topProbabilityLabelIndex;
169-
uint32_t oneOutputSize = results.Size() / 3;
208+
uint32_t oneOutputSize = results.Size() / BATCH_SIZE;
170209
for (uint32_t i = 0; i < oneOutputSize; i++) {
171-
if (results.GetAt(i + oneOutputSize) > topProbability) {
210+
if (results.GetAt(i + oneOutputSize * batchId) > topProbability) {
172211
topProbabilityLabelIndex = i;
173-
topProbability = results.GetAt(i + oneOutputSize);
212+
topProbability = results.GetAt(i + oneOutputSize * batchId);
174213
}
175214
}
176215
// Display the result

Samples/BatchSupport/BatchSupport/SampleHelper.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ namespace SampleHelper
2121
// Load object detection labels
2222
std::vector<std::string> LoadLabels(std::string labelsFilePath);
2323

24+
// Create input Tensorfloats with 3 images.
25+
winrt::Windows::AI::MachineLearning::TensorFloat CreateInputTensorFloat();
26+
27+
// Create input VideoFrames with 3 images
28+
winrt::Windows::Foundation::Collections::IVector<winrt::Windows::Media::VideoFrame> CreateVideoFrames();
29+
30+
winrt::hstring GetModelPath(std::string modelType);
31+
2432
void PrintResults(winrt::Windows::Foundation::Collections::IVectorView<float> results);
2533

2634
}

Samples/BatchSupport/BatchSupport/main.cpp

Lines changed: 59 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -12,100 +12,79 @@ using namespace winrt::Windows::Storage::Streams;
1212
using namespace std;
1313

1414
#define BATCH_SIZE 3
15-
hstring executionPath = static_cast<hstring>(SampleHelper::GetModulePath().c_str());
1615
string modelType = "freeBatchSize";
17-
bool ParseArgs(int argc, char* argv[]);
16+
string inputType = "TensorFloat";
1817

19-
int main(int argc, char* argv[])
20-
{
21-
init_apartment();
18+
hstring executionPath =
19+
static_cast<hstring>(SampleHelper::GetModulePath().c_str());
2220

23-
// did they pass in the args
24-
if (ParseArgs(argc, argv) == false)
25-
{
26-
printf("Usage: %s [fixedBatchSize|freeBatchSize]", argv[0]);
27-
}
21+
bool ParseArgs(int argc, char *argv[]);
2822

29-
// Get model path and image path
30-
hstring modelPath;
31-
if (modelType == "fixedBatchSize") {
32-
modelPath = executionPath + L"SqueezeNet_free.onnx";
33-
}
34-
else {
35-
modelPath = executionPath + L"SqueezeNet.onnx";
36-
}
37-
auto imagePath = executionPath + L"kitten_224.png";
23+
int main(int argc, char *argv[]) {
24+
init_apartment();
3825

39-
// load the model
40-
printf("Loading modelfile '%ws' on the CPU\n", modelPath.c_str());
41-
DWORD ticks = GetTickCount();
42-
auto model = LearningModel::LoadFromFilePath(modelPath);
43-
ticks = GetTickCount() - ticks;
44-
printf("model file loaded in %d ticks\n", ticks);
26+
// did they pass in the args
27+
if (ParseArgs(argc, argv) == false) {
28+
printf("Usage: %s [fixedBatchSize|freeBatchSize] [TensorFloat|VideoFrame] \n", argv[0]);
29+
}
4530

46-
// load the image
47-
printf("Loading the image...\n");
48-
auto imageFrame = SampleHelper::LoadImageFile(imagePath);
49-
50-
// Create input Tensorfloats with 3 copied tensors.
51-
std::vector<float> imageVector = SampleHelper::SoftwareBitmapToSoftwareTensor(imageFrame.SoftwareBitmap());
52-
std::vector<float> inputVector = imageVector;
53-
inputVector.insert(inputVector.end(), imageVector.begin(), imageVector.end());
54-
inputVector.insert(inputVector.end(), imageVector.begin(), imageVector.end());
31+
// load the model
32+
hstring modelPath = SampleHelper::GetModelPath(modelType);
33+
printf("Loading modelfile '%ws' on the CPU\n", modelPath.c_str());
34+
DWORD ticks = GetTickCount();
35+
auto model = LearningModel::LoadFromFilePath(modelPath);
36+
ticks = GetTickCount() - ticks;
37+
printf("model file loaded in %d ticks\n", ticks);
38+
39+
// now create a session and binding
40+
LearningModelDeviceKind deviceKind = LearningModelDeviceKind::Cpu;
41+
LearningModelSessionOptions options;
42+
if ("freeBatchSize" == modelType) {
43+
// If the model has free dimentional batch, override the free dimension with batch_size
44+
options.BatchSizeOverride(static_cast<uint32_t>(BATCH_SIZE));
45+
}
46+
LearningModelSession session(model, LearningModelDevice(deviceKind), options);
47+
LearningModelBinding binding(session);
5548

56-
auto inputShape = std::vector<int64_t>{ BATCH_SIZE, 3, 224, 224 };
57-
auto inputValue =
58-
TensorFloat::CreateFromIterable(
59-
inputShape,
60-
single_threaded_vector<float>(std::move(inputVector)).GetView());
49+
// bind the intput image
50+
printf("Binding...\n");
51+
auto inputFeatureDescriptor = model.InputFeatures().First();
6152

53+
if (inputType == "TensorFloat") { // if bind TensorFloat
54+
// Create input Tensorfloats with 3 copied tensors.
55+
TensorFloat inputTensorValue = SampleHelper::CreateInputTensorFloat();
56+
binding.Bind(inputFeatureDescriptor.Current().Name(), inputTensorValue);
57+
} else { // else bind VideoFrames
6258
// Create input VideoFrames with 3 copied images
63-
vector<VideoFrame> inputFrames = {};
64-
for (uint32_t i = 0; i < BATCH_SIZE; ++i) {
65-
inputFrames.emplace_back(imageFrame);
66-
}
67-
auto videoFrames = winrt::single_threaded_vector(move(inputFrames));
68-
69-
// now create a session and binding
70-
LearningModelDeviceKind deviceKind = LearningModelDeviceKind::Cpu;
71-
72-
LearningModelSessionOptions options;
73-
if ("freeBatchSize" == modelType) {
74-
options.BatchSizeOverride(static_cast<uint32_t>(BATCH_SIZE));
75-
}
76-
LearningModelSession session(model, LearningModelDevice(deviceKind), options);
77-
LearningModelBinding binding(session);
78-
79-
// bind the intput image
80-
printf("Binding...\n");
81-
auto inputFeatureDescriptor = model.InputFeatures().First();
82-
binding.Bind(inputFeatureDescriptor.Current().Name(), videoFrames);
83-
84-
// bind output tensor
85-
auto outputShape = std::vector<int64_t>{ BATCH_SIZE, 1000, 1, 1};
86-
auto outputValue = TensorFloat::Create(outputShape);
87-
std::wstring outputDataBindingName = std::wstring(model.OutputFeatures().First().Current().Name());
88-
binding.Bind(outputDataBindingName, outputValue);
89-
90-
91-
// bind output videoFrames
92-
// now run the model
93-
printf("Running the model...\n");
94-
ticks = GetTickCount();
95-
auto results = session.EvaluateAsync(binding, L"RunId").get();
96-
ticks = GetTickCount() - ticks;
97-
printf("model run took %d ticks\n", ticks);
98-
99-
SampleHelper::PrintResults(outputValue.GetAsVectorView());
59+
auto inputVideoFrames = SampleHelper::CreateVideoFrames();
60+
binding.Bind(inputFeatureDescriptor.Current().Name(), inputVideoFrames);
61+
}
10062

63+
// bind output tensor
64+
auto outputShape = std::vector<int64_t>{BATCH_SIZE, 1000, 1, 1};
65+
auto outputValue = TensorFloat::Create(outputShape);
66+
std::wstring outputDataBindingName =
67+
std::wstring(model.OutputFeatures().First().Current().Name());
68+
binding.Bind(outputDataBindingName, outputValue);
69+
70+
// now run the model
71+
printf("Running the model...\n");
72+
ticks = GetTickCount();
73+
auto results = session.EvaluateAsync(binding, L"RunId").get();
74+
ticks = GetTickCount() - ticks;
75+
printf("model run took %d ticks\n", ticks);
76+
77+
// Print Results
78+
SampleHelper::PrintResults(outputValue.GetAsVectorView());
10179
}
10280

103-
bool ParseArgs(int argc, char* argv[])
104-
{
105-
if (argc < 2)
106-
{
81+
bool ParseArgs(int argc, char *argv[]) {
82+
if (argc < 2) {
10783
return false;
10884
}
10985
modelType = argv[1];
86+
if (argc > 3) {
87+
inputType = argv[2];
88+
}
11089
return true;
11190
}

0 commit comments

Comments
 (0)