@@ -124,33 +124,18 @@ HRESULT LoadModel(LearningModel& model, const std::wstring& path, bool capturePe
124124 return S_OK;
125125}
126126
127- void PopulateSessionOptions (LearningModelSessionOptions& sessionOptions)
128- {
129- // Batch Size Override as 1
130- try
131- {
132- sessionOptions.BatchSizeOverride (1 );
133- }
134- catch (...)
135- {
136- printf (" Batch size override couldn't be set.\n " );
137- throw ;
138- }
139- }
140-
141127void CreateSessionConsideringSupportForSessionOptions (LearningModelSession& session,
142128 LearningModel& model,
143129 Profiler<WINML_MODEL_TEST_PERF>& profiler,
144130 CommandLineArgs& args,
145- const LearningModelDeviceWithMetadata& learningModelDevice)
131+ const LearningModelDeviceWithMetadata& learningModelDevice,
132+ const LearningModelSessionOptions& sessionOptions)
146133{
147134 auto statics = get_activation_factory<ApiInformation, IApiInformationStatics>();
148135 bool isSessionOptionsTypePresent = isSessionOptionsTypePresent =
149136 statics.IsTypePresent (L" Windows.AI.MachineLearning.LearningModelSessionOptions" );
150137 if (isSessionOptionsTypePresent)
151138 {
152- LearningModelSessionOptions sessionOptions;
153- PopulateSessionOptions (sessionOptions);
154139 if (args.IsPerformanceCapture ())
155140 {
156141 WINML_PROFILING_START (profiler, WINML_MODEL_TEST_PERF::CREATE_SESSION);
@@ -175,16 +160,21 @@ void CreateSessionConsideringSupportForSessionOptions(LearningModelSession& sess
175160 }
176161}
177162
178- HRESULT CreateSession (LearningModelSession& session, LearningModel& model, const LearningModelDeviceWithMetadata& learningModelDevice,
179- CommandLineArgs& args, OutputHelper& output, Profiler<WINML_MODEL_TEST_PERF>& profiler)
163+ HRESULT CreateSession (LearningModelSession& session,
164+ LearningModel& model,
165+ const LearningModelDeviceWithMetadata& learningModelDevice,
166+ CommandLineArgs& args,
167+ OutputHelper& output,
168+ Profiler<WINML_MODEL_TEST_PERF>& profiler,
169+ const LearningModelSessionOptions& sessionOptions)
180170{
181171 if (model == nullptr )
182172 {
183173 return hresult_invalid_argument ().code ();
184174 }
185175 try
186176 {
187- CreateSessionConsideringSupportForSessionOptions (session, model, profiler, args, learningModelDevice);
177+ CreateSessionConsideringSupportForSessionOptions (session, model, profiler, args, learningModelDevice, sessionOptions );
188178 }
189179 catch (hresult_error hr)
190180 {
@@ -535,7 +525,8 @@ void RunConfiguration(CommandLineArgs& args, OutputHelper& output, LearningModel
535525}
536526int run (CommandLineArgs& args,
537527 Profiler<WINML_MODEL_TEST_PERF>& profiler,
538- const std::vector<LearningModelDeviceWithMetadata>& deviceList) try
528+ const std::vector<LearningModelDeviceWithMetadata>& deviceList,
529+ const LearningModelSessionOptions& sessionOptions) try
539530{
540531 // Initialize COM in a multi-threaded environment.
541532 winrt::init_apartment ();
@@ -600,7 +591,7 @@ int run(CommandLineArgs& args,
600591 sessionCreationIteration < args.NumSessionCreationIterations ();
601592 sessionCreationIteration++)
602593 {
603- lastHr = CreateSession (session, model, learningModelDevice,args, output, profiler);
594+ lastHr = CreateSession (session, model, learningModelDevice,args, output, profiler, sessionOptions );
604595 if (FAILED (lastHr))
605596 {
606597 continue ;
0 commit comments