@@ -12,100 +12,79 @@ using namespace winrt::Windows::Storage::Streams;
1212using namespace std ;
1313
1414#define BATCH_SIZE 3
15- hstring executionPath = static_cast <hstring>(SampleHelper::GetModulePath().c_str());
1615string modelType = " freeBatchSize" ;
17- bool ParseArgs ( int argc, char * argv[]) ;
16+ string inputType = " TensorFloat " ;
1817
19- int main (int argc, char * argv[])
20- {
21- init_apartment ();
18+ hstring executionPath =
19+ static_cast <hstring>(SampleHelper::GetModulePath().c_str());
2220
23- // did they pass in the args
24- if (ParseArgs (argc, argv) == false )
25- {
26- printf (" Usage: %s [fixedBatchSize|freeBatchSize]" , argv[0 ]);
27- }
21+ bool ParseArgs (int argc, char *argv[]);
2822
29- // Get model path and image path
30- hstring modelPath;
31- if (modelType == " fixedBatchSize" ) {
32- modelPath = executionPath + L" SqueezeNet_free.onnx" ;
33- }
34- else {
35- modelPath = executionPath + L" SqueezeNet.onnx" ;
36- }
37- auto imagePath = executionPath + L" kitten_224.png" ;
23+ int main (int argc, char *argv[]) {
24+ init_apartment ();
3825
39- // load the model
40- printf (" Loading modelfile '%ws' on the CPU\n " , modelPath.c_str ());
41- DWORD ticks = GetTickCount ();
42- auto model = LearningModel::LoadFromFilePath (modelPath);
43- ticks = GetTickCount () - ticks;
44- printf (" model file loaded in %d ticks\n " , ticks);
26+ // did they pass in the args
27+ if (ParseArgs (argc, argv) == false ) {
28+ printf (" Usage: %s [fixedBatchSize|freeBatchSize] [TensorFloat|VideoFrame] \n " , argv[0 ]);
29+ }
4530
46- // load the image
47- printf (" Loading the image...\n " );
48- auto imageFrame = SampleHelper::LoadImageFile (imagePath);
49-
50- // Create input Tensorfloats with 3 copied tensors.
51- std::vector<float > imageVector = SampleHelper::SoftwareBitmapToSoftwareTensor (imageFrame.SoftwareBitmap ());
52- std::vector<float > inputVector = imageVector;
53- inputVector.insert (inputVector.end (), imageVector.begin (), imageVector.end ());
54- inputVector.insert (inputVector.end (), imageVector.begin (), imageVector.end ());
31+ // load the model
32+ hstring modelPath = SampleHelper::GetModelPath (modelType);
33+ printf (" Loading modelfile '%ws' on the CPU\n " , modelPath.c_str ());
34+ DWORD ticks = GetTickCount ();
35+ auto model = LearningModel::LoadFromFilePath (modelPath);
36+ ticks = GetTickCount () - ticks;
37+ printf (" model file loaded in %d ticks\n " , ticks);
38+
39+ // now create a session and binding
40+ LearningModelDeviceKind deviceKind = LearningModelDeviceKind::Cpu;
41+ LearningModelSessionOptions options;
42+ if (" freeBatchSize" == modelType) {
43+ // If the model has free dimentional batch, override the free dimension with batch_size
44+ options.BatchSizeOverride (static_cast <uint32_t >(BATCH_SIZE));
45+ }
46+ LearningModelSession session (model, LearningModelDevice (deviceKind), options);
47+ LearningModelBinding binding (session);
5548
56- auto inputShape = std::vector<int64_t >{ BATCH_SIZE, 3 , 224 , 224 };
57- auto inputValue =
58- TensorFloat::CreateFromIterable (
59- inputShape,
60- single_threaded_vector<float >(std::move (inputVector)).GetView ());
49+ // bind the intput image
50+ printf (" Binding...\n " );
51+ auto inputFeatureDescriptor = model.InputFeatures ().First ();
6152
53+ if (inputType == " TensorFloat" ) { // if bind TensorFloat
54+ // Create input Tensorfloats with 3 copied tensors.
55+ TensorFloat inputTensorValue = SampleHelper::CreateInputTensorFloat ();
56+ binding.Bind (inputFeatureDescriptor.Current ().Name (), inputTensorValue);
57+ } else { // else bind VideoFrames
6258 // Create input VideoFrames with 3 copied images
63- vector<VideoFrame> inputFrames = {};
64- for (uint32_t i = 0 ; i < BATCH_SIZE; ++i) {
65- inputFrames.emplace_back (imageFrame);
66- }
67- auto videoFrames = winrt::single_threaded_vector (move (inputFrames));
68-
69- // now create a session and binding
70- LearningModelDeviceKind deviceKind = LearningModelDeviceKind::Cpu;
71-
72- LearningModelSessionOptions options;
73- if (" freeBatchSize" == modelType) {
74- options.BatchSizeOverride (static_cast <uint32_t >(BATCH_SIZE));
75- }
76- LearningModelSession session (model, LearningModelDevice (deviceKind), options);
77- LearningModelBinding binding (session);
78-
79- // bind the intput image
80- printf (" Binding...\n " );
81- auto inputFeatureDescriptor = model.InputFeatures ().First ();
82- binding.Bind (inputFeatureDescriptor.Current ().Name (), videoFrames);
83-
84- // bind output tensor
85- auto outputShape = std::vector<int64_t >{ BATCH_SIZE, 1000 , 1 , 1 };
86- auto outputValue = TensorFloat::Create (outputShape);
87- std::wstring outputDataBindingName = std::wstring (model.OutputFeatures ().First ().Current ().Name ());
88- binding.Bind (outputDataBindingName, outputValue);
89-
90-
91- // bind output videoFrames
92- // now run the model
93- printf (" Running the model...\n " );
94- ticks = GetTickCount ();
95- auto results = session.EvaluateAsync (binding, L" RunId" ).get ();
96- ticks = GetTickCount () - ticks;
97- printf (" model run took %d ticks\n " , ticks);
98-
99- SampleHelper::PrintResults (outputValue.GetAsVectorView ());
59+ auto inputVideoFrames = SampleHelper::CreateVideoFrames ();
60+ binding.Bind (inputFeatureDescriptor.Current ().Name (), inputVideoFrames);
61+ }
10062
63+ // bind output tensor
64+ auto outputShape = std::vector<int64_t >{BATCH_SIZE, 1000 , 1 , 1 };
65+ auto outputValue = TensorFloat::Create (outputShape);
66+ std::wstring outputDataBindingName =
67+ std::wstring (model.OutputFeatures ().First ().Current ().Name ());
68+ binding.Bind (outputDataBindingName, outputValue);
69+
70+ // now run the model
71+ printf (" Running the model...\n " );
72+ ticks = GetTickCount ();
73+ auto results = session.EvaluateAsync (binding, L" RunId" ).get ();
74+ ticks = GetTickCount () - ticks;
75+ printf (" model run took %d ticks\n " , ticks);
76+
77+ // Print Results
78+ SampleHelper::PrintResults (outputValue.GetAsVectorView ());
10179}
10280
103- bool ParseArgs (int argc, char * argv[])
104- {
105- if (argc < 2 )
106- {
81+ bool ParseArgs (int argc, char *argv[]) {
82+ if (argc < 2 ) {
10783 return false ;
10884 }
10985 modelType = argv[1 ];
86+ if (argc > 3 ) {
87+ inputType = argv[2 ];
88+ }
11089 return true ;
11190}
0 commit comments