Skip to content

Commit e3ffaa7

Browse files
authored
Add time limit stop iterating for winmlrunner static lib (#246)
1 parent 99e9342 commit e3ffaa7

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

Tools/WinMLRunner/src/CommandLineArgs.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class CommandLineArgs
2121
bool IsAutoScale() const { return m_autoScale; }
2222
bool IsOutputPerf() const { return m_perfOutput; }
2323
bool IsSaveTensor() const { return m_saveTensor; }
24-
24+
bool IsTimeLimitIterations() const { return m_timeLimitIterations; }
2525
BitmapInterpolationMode AutoScaleInterpMode() const { return m_autoScaleInterpMode; }
2626

2727
const std::vector<std::wstring>& ImagePaths() const { return m_imagePaths; }
@@ -76,6 +76,7 @@ class CommandLineArgs
7676
bool IsImageInput() const { return !m_imagePaths.empty() && m_csvData.empty(); }
7777

7878
uint32_t NumIterations() const { return m_numIterations; }
79+
double IterationTimeLimit() const { return m_iterationTimeLimitMilliseconds; }
7980
uint32_t NumThreads() const { return m_numThreads; }
8081
uint32_t ThreadInterval() const { return m_threadInterval; } // Thread interval in milliseconds
8182
uint32_t TopK() const { return m_topK; }
@@ -114,6 +115,12 @@ class CommandLineArgs
114115
{
115116
m_perfFileMetadata.push_back(std::make_pair(key, value));
116117
}
118+
// Stop iterating when total time of iterations after the first iteration exceeds time limit.
119+
void SetIterationTimeLimit(const double milliseconds)
120+
{
121+
m_timeLimitIterations = true;
122+
m_iterationTimeLimitMilliseconds = milliseconds;
123+
}
117124
std::wstring SaveTensorMode() const { return m_saveTensorMode; }
118125

119126
private:
@@ -139,6 +146,7 @@ class CommandLineArgs
139146
bool m_perfOutput = false;
140147
BitmapInterpolationMode m_autoScaleInterpMode = BitmapInterpolationMode::Cubic;
141148
bool m_saveTensor = false;
149+
bool m_timeLimitIterations = false;
142150
std::wstring m_saveTensorMode = L"First";
143151

144152
std::wstring m_modelFolderPath;
@@ -153,6 +161,7 @@ class CommandLineArgs
153161
std::wstring m_perfOutputPath;
154162
std::wstring m_perIterationDataPath;
155163
uint32_t m_numIterations = 1;
164+
double m_iterationTimeLimitMilliseconds = 0;
156165
uint32_t m_numThreads = 1;
157166
uint32_t m_threadInterval = 0;
158167
uint32_t m_topK = 1;

Tools/WinMLRunner/src/Run.cpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -660,43 +660,59 @@ void RunConfiguration(CommandLineArgs& args, OutputHelper& output, LearningModel
660660
const DeviceCreationLocation deviceCreationLocation, Profiler<WINML_MODEL_TEST_PERF>& profiler,
661661
const std::wstring& modelPath, const std::wstring& imagePath)
662662
{
663-
for (uint32_t i = 0; i < args.NumIterations(); i++)
663+
Timer iterationTimer;
664+
uint32_t iterationNum = 0;
665+
for (; iterationNum < args.NumIterations(); iterationNum++)
664666
{
665667
#if defined(_AMD64_)
666668
// PIX markers only work on AMD64
667669
// If PIX tool was attached then capture already began for the first iteration before
668670
// session creation. This is to begin PIX capture for each iteration after the first
669671
// iteration.
670-
if (i > 0)
672+
if (iterationNum > 0)
671673
{
672674
StartPIXCapture(output);
673675
}
674676
#endif
677+
if (args.IsTimeLimitIterations())
678+
{
679+
if (iterationNum == 1)
680+
{
681+
iterationTimer.Start();
682+
}
683+
else if (iterationNum >= 1 && iterationTimer.Stop() >= args.IterationTimeLimit())
684+
{
685+
std::cout << "Iteration time exceeded limit specified. Exiting.." << std::endl;
686+
break;
687+
}
688+
}
675689
LearningModelBinding context(session);
676690
lastHr = BindInputs(context, model, session, output, deviceType, args, inputBindingType, inputDataType,
677-
winrtDevice, deviceCreationLocation, i, profiler, imagePath);
691+
winrtDevice, deviceCreationLocation, iterationNum, profiler, imagePath);
678692
if (FAILED(lastHr))
679693
{
680694
break;
681695
}
682696
LearningModelEvaluationResult result = nullptr;
683697
bool capture_perf = args.IsPerformanceCapture() || args.IsPerIterationCapture();
684-
lastHr = EvaluateModel(result, model, context, session, args, output, capture_perf, i, profiler);
698+
lastHr = EvaluateModel(result, model, context, session, args, output, capture_perf, iterationNum, profiler);
685699
if (FAILED(lastHr))
686700
{
687-
output.PrintEvaluatingInfo(i + 1, deviceType, inputBindingType, inputDataType, deviceCreationLocation,
701+
output.PrintEvaluatingInfo(iterationNum + 1, deviceType, inputBindingType, inputDataType,
702+
deviceCreationLocation,
688703
"[FAILED]");
689704
break;
690705
}
691-
else if (!args.TerseOutput() || i == 0)
706+
else if (!args.TerseOutput() || iterationNum == 0)
692707
{
693-
output.PrintEvaluatingInfo(i + 1, deviceType, inputBindingType, inputDataType, deviceCreationLocation,
708+
output.PrintEvaluatingInfo(iterationNum + 1, deviceType, inputBindingType, inputDataType,
709+
deviceCreationLocation,
694710
"[SUCCESS]");
695711

696712
// Only print eval results on the first iteration, iff it's not garbage data
697713
if (!args.IsGarbageInput() || args.IsSaveTensor())
698714
{
699-
BindingUtilities::PrintOrSaveEvaluationResults(model, args, result.Outputs(), output, i);
715+
BindingUtilities::PrintOrSaveEvaluationResults(model, args, result.Outputs(), output, iterationNum);
700716
}
701717

702718
if (args.TerseOutput() && args.NumIterations() > 1)
@@ -713,15 +729,15 @@ void RunConfiguration(CommandLineArgs& args, OutputHelper& output, LearningModel
713729
// print metrics after iterations
714730
if (SUCCEEDED(lastHr) && args.IsPerformanceCapture())
715731
{
716-
output.PrintResults(profiler, args.NumIterations(), deviceType, inputBindingType, inputDataType,
732+
output.PrintResults(profiler, iterationNum, deviceType, inputBindingType, inputDataType,
717733
deviceCreationLocation, args.IsPerformanceConsoleOutputVerbose());
718734
if (args.IsOutputPerf())
719735
{
720736
std::string deviceTypeStringified = TypeHelper::Stringify(deviceType);
721737
std::string inputDataTypeStringified = TypeHelper::Stringify(inputDataType);
722738
std::string inputBindingTypeStringified = TypeHelper::Stringify(inputBindingType);
723739
std::string deviceCreationLocationStringified = TypeHelper::Stringify(deviceCreationLocation);
724-
output.WritePerformanceDataToCSV(profiler, args.NumIterations(), modelPath, deviceTypeStringified,
740+
output.WritePerformanceDataToCSV(profiler, iterationNum, modelPath, deviceTypeStringified,
725741
inputDataTypeStringified, inputBindingTypeStringified,
726742
deviceCreationLocationStringified, args.GetPerformanceFileMetadata());
727743
}

0 commit comments

Comments
 (0)