Skip to content

Commit ad21443

Browse files
author
Ryan Lai
authored
Add functions to allow ILearningModelFeatureValues to be plumbed through as input through WinMLRunner static library (#327)
* Make changes to allow protobuf input through commandline args * use protobuf helper * Use feature values instead of protobuf * Commandline args changes * added clear method * Naming changes * Formatting * DOn't modify git modules
1 parent 283c622 commit ad21443

File tree

4 files changed

+27
-9
lines changed

4 files changed

+27
-9
lines changed

Tools/WinMLRunner/onnxruntime

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit ccbf49e59f6bef897a94595e2213263b37d64ff3

Tools/WinMLRunner/src/CommandLineArgs.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,3 +600,7 @@ void CommandLineArgs::AddPerformanceFileMetadata(const std::string& key, const s
600600
cleanedValue.erase(std::remove_copy(value.begin(), value.end(), cleanedValue.begin(), ','), cleanedValue.end());
601601
m_perfFileMetadata.push_back(std::make_pair(key, cleanedValue));
602602
}
603+
void CommandLineArgs::AddProvidedInputFeatureValue(const ILearningModelFeatureValue& input)
604+
{
605+
m_providedInputFeatureValues.push_back(input);
606+
}

Tools/WinMLRunner/src/CommandLineArgs.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ class CommandLineArgs
4545
BitmapInterpolationMode AutoScaleInterpMode() const { return m_autoScaleInterpMode; }
4646

4747
const std::vector<std::wstring>& ImagePaths() const { return m_imagePaths; }
48+
const std::vector<ILearningModelFeatureValue>& ProvidedInputFeatureValues() const
49+
{
50+
return m_providedInputFeatureValues;
51+
}
4852
const std::wstring& CsvPath() const { return m_csvData; }
4953
const std::wstring& OutputPath() const { return m_perfOutputPath; }
5054
const std::wstring& FolderPath() const { return m_modelFolderPath; }
@@ -92,11 +96,11 @@ class CommandLineArgs
9296
bool IsGarbageInput() const
9397
{
9498
// When there is no image or csv input provided, then garbage input binding is used.
95-
return m_imagePaths.empty() && m_csvData.empty();
99+
return m_imagePaths.empty() && m_csvData.empty() && m_providedInputFeatureValues.empty();
96100
}
97101
bool IsCSVInput() const { return m_imagePaths.empty() && !m_csvData.empty(); }
98102
bool IsImageInput() const { return !m_imagePaths.empty() && m_csvData.empty(); }
99-
103+
bool InputFeatureValuesProvided() const { return !m_providedInputFeatureValues.empty(); }
100104
uint32_t NumIterations() const { return m_numIterations; }
101105
uint32_t NumLoadIterations() const { return m_numLoadIterations; }
102106
uint32_t NumSessionCreationIterations() const { return m_numSessionIterations; }
@@ -140,6 +144,8 @@ class CommandLineArgs
140144
void SetSessionCreationIterations(const uint32_t iterations) { m_numSessionIterations = iterations; }
141145
void SetLoadIterations(const uint32_t iterations) { m_numLoadIterations = iterations; }
142146
void AddPerformanceFileMetadata(const std::string& key, const std::string& value);
147+
void AddProvidedInputFeatureValue(const ILearningModelFeatureValue& input);
148+
void ClearProvidedInputFeatureValues() { m_providedInputFeatureValues.clear(); };
143149
void SetGarbageDataMaxValue(const uint32_t value) { m_garbageDataMaxValue = value; }
144150

145151
// Stop iterating when total time of iterations after the first iteration exceeds time limit.
@@ -185,6 +191,7 @@ class CommandLineArgs
185191
std::wstring m_modelFolderPath;
186192
std::wstring m_modelPath;
187193
std::vector<std::wstring> m_imagePaths;
194+
std::vector<ILearningModelFeatureValue> m_providedInputFeatureValues;
188195
std::wstring m_inputImageFolderPath;
189196
std::wstring m_csvData;
190197
std::wstring m_inputData;

Tools/WinMLRunner/src/Run.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -221,17 +221,23 @@ HRESULT BindInputs(LearningModelBinding& context, const LearningModelSession& se
221221
bool captureIterationPerf = args.IsPerformanceCapture() || args.IsPerIterationCapture();
222222

223223
std::vector<ILearningModelFeatureValue> inputFeatures;
224-
try
224+
if (args.InputFeatureValuesProvided())
225225
{
226-
inputFeatures = GenerateInputFeatures(session.Model(), args, inputBindingType, inputDataType, device, iteration, imagePath);
226+
inputFeatures = args.ProvidedInputFeatureValues();
227227
}
228-
catch (hresult_error hr)
228+
else
229229
{
230-
std::wcout << "\nGenerating Input Features [FAILED]" << std::endl;
231-
std::wcout << hr.message().c_str() << std::endl;
232-
return hr.code();
230+
try
231+
{
232+
inputFeatures = GenerateInputFeatures(session.Model(), args, inputBindingType, inputDataType, device, iteration, imagePath);
233+
}
234+
catch (hresult_error hr)
235+
{
236+
std::wcout << "\nGenerating Input Features [FAILED]" << std::endl;
237+
std::wcout << hr.message().c_str() << std::endl;
238+
return hr.code();
239+
}
233240
}
234-
235241
HRESULT bindInputResult =
236242
BindInputFeatures(session.Model(), context, inputFeatures, args, output, captureIterationPerf, iteration, profiler);
237243

0 commit comments

Comments
 (0)