Skip to content

Commit 29e7c24

Browse files
author
Ryan Lai
authored
CHange winmlrunner to accept session options through static lib (#337)
1 parent a8d6864 commit 29e7c24

File tree

3 files changed

+32
-24
lines changed

3 files changed

+32
-24
lines changed

Tools/WinMLRunner/src/Run.cpp

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -124,33 +124,18 @@ HRESULT LoadModel(LearningModel& model, const std::wstring& path, bool capturePe
124124
return S_OK;
125125
}
126126

127-
void PopulateSessionOptions(LearningModelSessionOptions& sessionOptions)
128-
{
129-
// Batch Size Override as 1
130-
try
131-
{
132-
sessionOptions.BatchSizeOverride(1);
133-
}
134-
catch (...)
135-
{
136-
printf("Batch size override couldn't be set.\n");
137-
throw;
138-
}
139-
}
140-
141127
void CreateSessionConsideringSupportForSessionOptions(LearningModelSession& session,
142128
LearningModel& model,
143129
Profiler<WINML_MODEL_TEST_PERF>& profiler,
144130
CommandLineArgs& args,
145-
const LearningModelDeviceWithMetadata& learningModelDevice)
131+
const LearningModelDeviceWithMetadata& learningModelDevice,
132+
const LearningModelSessionOptions& sessionOptions)
146133
{
147134
auto statics = get_activation_factory<ApiInformation, IApiInformationStatics>();
148135
bool isSessionOptionsTypePresent = isSessionOptionsTypePresent =
149136
statics.IsTypePresent(L"Windows.AI.MachineLearning.LearningModelSessionOptions");
150137
if (isSessionOptionsTypePresent)
151138
{
152-
LearningModelSessionOptions sessionOptions;
153-
PopulateSessionOptions(sessionOptions);
154139
if (args.IsPerformanceCapture())
155140
{
156141
WINML_PROFILING_START(profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
@@ -175,16 +160,21 @@ void CreateSessionConsideringSupportForSessionOptions(LearningModelSession& sess
175160
}
176161
}
177162

178-
HRESULT CreateSession(LearningModelSession& session, LearningModel& model, const LearningModelDeviceWithMetadata& learningModelDevice,
179-
CommandLineArgs& args, OutputHelper& output, Profiler<WINML_MODEL_TEST_PERF>& profiler)
163+
HRESULT CreateSession(LearningModelSession& session,
164+
LearningModel& model,
165+
const LearningModelDeviceWithMetadata& learningModelDevice,
166+
CommandLineArgs& args,
167+
OutputHelper& output,
168+
Profiler<WINML_MODEL_TEST_PERF>& profiler,
169+
const LearningModelSessionOptions& sessionOptions)
180170
{
181171
if (model == nullptr)
182172
{
183173
return hresult_invalid_argument().code();
184174
}
185175
try
186176
{
187-
CreateSessionConsideringSupportForSessionOptions(session, model, profiler, args, learningModelDevice);
177+
CreateSessionConsideringSupportForSessionOptions(session, model, profiler, args, learningModelDevice, sessionOptions);
188178
}
189179
catch (hresult_error hr)
190180
{
@@ -535,7 +525,8 @@ void RunConfiguration(CommandLineArgs& args, OutputHelper& output, LearningModel
535525
}
536526
int run(CommandLineArgs& args,
537527
Profiler<WINML_MODEL_TEST_PERF>& profiler,
538-
const std::vector<LearningModelDeviceWithMetadata>& deviceList) try
528+
const std::vector<LearningModelDeviceWithMetadata>& deviceList,
529+
const LearningModelSessionOptions& sessionOptions) try
539530
{
540531
// Initialize COM in a multi-threaded environment.
541532
winrt::init_apartment();
@@ -600,7 +591,7 @@ int run(CommandLineArgs& args,
600591
sessionCreationIteration < args.NumSessionCreationIterations();
601592
sessionCreationIteration++)
602593
{
603-
lastHr = CreateSession(session, model, learningModelDevice,args, output, profiler);
594+
lastHr = CreateSession(session, model, learningModelDevice,args, output, profiler, sessionOptions);
604595
if (FAILED(lastHr))
605596
{
606597
continue;

Tools/WinMLRunner/src/Run.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33

44
int run(CommandLineArgs& args,
55
Profiler<WINML_MODEL_TEST_PERF>& profiler,
6-
const std::vector<LearningModelDeviceWithMetadata>& deviceList);
6+
const std::vector<LearningModelDeviceWithMetadata>& deviceList,
7+
const LearningModelSessionOptions& sessionOptions);

Tools/WinMLRunner/src/main.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@
55
#include <codecvt>
66
using namespace std;
77

8+
void PopulateSessionOptions(LearningModelSessionOptions& sessionOptions)
9+
{
10+
// Batch Size Override as 1
11+
try
12+
{
13+
sessionOptions.BatchSizeOverride(1);
14+
}
15+
catch (...)
16+
{
17+
printf("Batch size override couldn't be set.\n");
18+
throw;
19+
}
20+
}
21+
822
int main(int argc, char *argv[])
923
{
1024
std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> converter;
@@ -36,7 +50,9 @@ int main(int argc, char *argv[])
3650
wprintf(error.message().c_str());
3751
return error.code();
3852
}
39-
int returnCode = run(*commandLineArgs, profiler, deviceList);
53+
LearningModelSessionOptions sessionOptions;
54+
PopulateSessionOptions(sessionOptions);
55+
int returnCode = run(*commandLineArgs, profiler, deviceList, sessionOptions);
4056
free(commandLineArgs);
4157
return returnCode;
4258
}

0 commit comments

Comments
 (0)