33using Microsoft . UI . Xaml . Controls ;
44using System ;
55using System . Collections . Generic ;
6+ using System . IO ;
67using System . Threading . Tasks ;
78using Windows . Graphics . Imaging ;
89using Windows . Media ;
@@ -19,14 +20,15 @@ public sealed class EvalResult
1920
2021 public sealed partial class Batching : Page
2122 {
22- const int numInputImages = 50 ;
23- const int numEvalIterations = 100 ;
23+ const int NumInputImages = 50 ;
24+ const int NumEvalIterations = 100 ;
2425
25- private LearningModelSession nonBatchingSession_ ;
26- private LearningModelSession batchingSession_ ;
27-
28- float avgNonBatchedDuration_ = 0 ;
29- float avgBatchDuration_ = 0 ;
26+ private LearningModel _model = null ;
27+ private LearningModelSession _nonBatchingSession = null ;
28+ private LearningModelSession _batchingSession = null ;
29+
30+ float _avgNonBatchedDuration = 0 ;
31+ float _avgBatchDuration = 0 ;
3032
3133 // Marked volatile since it's updated across threads
3234 static volatile bool navigatingAwayFromPage = false ;
@@ -36,28 +38,35 @@ public Batching()
3638 {
3739 this . InitializeComponent ( ) ;
3840 // Ensure static variable is always false on page initialization
39- navigatingAwayFromPage = false ;
41+ navigatingAwayFromPage = false ;
42+
43+ // Load the model
44+ var modelName = "squeezenet1.1-7-batched.onnx" ;
45+ var modelPath = Path . Join ( Windows . ApplicationModel . Package . Current . InstalledLocation . Path , "Models" , modelName ) ;
46+ _model = LearningModel . LoadFromFilePath ( modelPath ) ;
4047 }
4148
4249 async private void StartInference ( object sender , RoutedEventArgs e )
4350 {
44- ShowEvalUI ( ) ;
45- ResetEvalMetrics ( ) ;
51+ ShowStatus ( ) ;
52+ ResetMetrics ( ) ;
4653
4754 var inputImages = await GetInputImages ( ) ;
4855 int batchSize = GetBatchSizeFromBatchSizeSlider ( ) ;
49- await CreateSessions ( batchSize ) ;
5056
51- UpdateEvalText ( false ) ;
57+ _nonBatchingSession = await CreateLearningModelSession ( _model ) ;
58+ _batchingSession = await CreateLearningModelSession ( _model , batchSize ) ;
59+
60+ UpdateStatus ( false ) ;
5261 await Classify ( inputImages ) ;
5362
54- UpdateEvalText ( true ) ;
63+ UpdateStatus ( true ) ;
5564 await ClassifyBatched ( inputImages , batchSize ) ;
5665
57- GenerateEvalResultAndUI ( ) ;
66+ ShowUI ( ) ;
5867 }
5968
60- private void ShowEvalUI ( )
69+ private void ShowStatus ( )
6170 {
6271 StartInferenceBtn . IsEnabled = false ;
6372 BatchSizeSlider . IsEnabled = false ;
@@ -66,10 +75,10 @@ private void ShowEvalUI()
6675 LoadingContainer . Visibility = Visibility . Visible ;
6776 }
6877
69- private void ResetEvalMetrics ( )
78+ private void ResetMetrics ( )
7079 {
71- avgNonBatchedDuration_ = 0 ;
72- avgBatchDuration_ = 0 ;
80+ _avgNonBatchedDuration = 0 ;
81+ _avgBatchDuration = 0 ;
7382 }
7483
7584 // Test input consists of 50 images (25 bird and 25 cat)
@@ -80,7 +89,7 @@ private async Task<List<VideoFrame>> GetInputImages()
8089 var birdImage = await CreateSoftwareBitmapFromStorageFile ( birdFile ) ;
8190 var catImage = await CreateSoftwareBitmapFromStorageFile ( catFile ) ;
8291 var inputImages = new List < VideoFrame > ( ) ;
83- for ( int i = 0 ; i < numInputImages / 2 ; i ++ )
92+ for ( int i = 0 ; i < NumInputImages / 2 ; i ++ )
8493 {
8594 inputImages . Add ( VideoFrame . CreateWithSoftwareBitmap ( birdImage ) ) ;
8695 inputImages . Add ( VideoFrame . CreateWithSoftwareBitmap ( catImage ) ) ;
@@ -96,30 +105,23 @@ private async Task<SoftwareBitmap> CreateSoftwareBitmapFromStorageFile(StorageFi
96105 return bitmap ;
97106 }
98107
99- private void UpdateEvalText ( bool isBatchingEval )
100- {
101- if ( isBatchingEval )
102- EvalText . Text = "Inferencing Batched Inputs:" ;
103- else
104- EvalText . Text = "Inferencing Non-Batched Inputs:" ;
105- }
106-
107- private async Task CreateSessions ( int batchSizeOverride )
108+ private void UpdateStatus ( bool isBatchingEval )
108109 {
109- var modelPath = "ms-appx:///Models/squeezenet1.1-7-batched.onnx" ;
110- nonBatchingSession_ = await CreateLearningModelSession ( modelPath ) ;
111- batchingSession_ = await CreateLearningModelSession ( modelPath , batchSizeOverride ) ;
110+ if ( isBatchingEval )
111+ {
112+ EvalText . Text = "Inferencing Batched Inputs:" ;
113+ }
114+ else
115+ {
116+ EvalText . Text = "Inferencing Non-Batched Inputs:" ;
117+ }
112118 }
113119
114- private async Task < LearningModelSession > CreateLearningModelSession ( string modelPath , int batchSizeOverride = - 1 )
120+ private async Task < LearningModelSession > CreateLearningModelSession ( LearningModel model , int batchSizeOverride = - 1 )
115121 {
116- var model = await CreateLearningModel ( modelPath ) ;
117122 var deviceKind = DeviceComboBox . GetDeviceKind ( ) ;
118123 var device = new LearningModelDevice ( deviceKind ) ;
119- var options = new LearningModelSessionOptions ( )
120- {
121- CloseModelOnSessionCreation = true // Close the model to prevent extra memory usage
122- } ;
124+ var options = new LearningModelSessionOptions ( ) ;
123125 if ( batchSizeOverride > 0 )
124126 {
125127 options . BatchSizeOverride = ( uint ) batchSizeOverride ;
@@ -128,26 +130,21 @@ private async Task<LearningModelSession> CreateLearningModelSession(string model
128130 return session ;
129131 }
130132
131- private static async Task < LearningModel > CreateLearningModel ( string modelPath )
132- {
133- var uri = new Uri ( modelPath ) ;
134- var file = await StorageFile . GetFileFromApplicationUriAsync ( uri ) ;
135- var model = await LearningModel . LoadFromStorageFileAsync ( file ) ;
136- return model ;
137- }
138-
139133 async private Task Classify ( List < VideoFrame > inputImages )
140134 {
141135 float totalEvalDurations = 0 ;
142- for ( int i = 0 ; i < numEvalIterations ; i ++ )
136+ for ( int i = 0 ; i < NumEvalIterations ; i ++ )
143137 {
144- if ( navigatingAwayFromPage )
145- break ;
146- UpdateEvalProgressUI ( i ) ;
147- float evalDuration = await Task . Run ( ( ) => Evaluate ( nonBatchingSession_ , inputImages ) ) ;
138+ if ( navigatingAwayFromPage )
139+ {
140+ break ;
141+ }
142+
143+ UpdateProgress ( i ) ;
144+ float evalDuration = await Task . Run ( ( ) => Evaluate ( _nonBatchingSession , inputImages ) ) ;
148145 totalEvalDurations += evalDuration ;
149146 }
150- avgNonBatchedDuration_ = totalEvalDurations / numEvalIterations ;
147+ _avgNonBatchedDuration = totalEvalDurations / NumEvalIterations ;
151148 }
152149
153150 private static float Evaluate ( LearningModelSession session , List < VideoFrame > input )
@@ -157,8 +154,11 @@ private static float Evaluate(LearningModelSession session, List<VideoFrame> inp
157154 var binding = new LearningModelBinding ( session ) ;
158155 for ( int j = 0 ; j < input . Count ; j ++ )
159156 {
160- if ( navigatingAwayFromPage )
161- break ;
157+ if ( navigatingAwayFromPage )
158+ {
159+ break ;
160+ }
161+
162162 var start = HighResolutionClock . UtcNow ( ) ;
163163 binding . Bind ( inputName , input [ j ] ) ;
164164 session . Evaluate ( binding , "" ) ;
@@ -172,15 +172,15 @@ private static float Evaluate(LearningModelSession session, List<VideoFrame> inp
172172 async private Task ClassifyBatched ( List < VideoFrame > inputImages , int batchSize )
173173 {
174174 float totalEvalDurations = 0 ;
175- for ( int i = 0 ; i < numEvalIterations ; i ++ )
175+ for ( int i = 0 ; i < NumEvalIterations ; i ++ )
176176 {
177177 if ( navigatingAwayFromPage )
178178 break ;
179- UpdateEvalProgressUI ( i ) ;
180- float evalDuration = await Task . Run ( ( ) => EvaluateBatched ( batchingSession_ , inputImages , batchSize ) ) ;
179+ UpdateProgress ( i ) ;
180+ float evalDuration = await Task . Run ( ( ) => EvaluateBatched ( _batchingSession , inputImages , batchSize ) ) ;
181181 totalEvalDurations += evalDuration ;
182182 }
183- avgBatchDuration_ = totalEvalDurations / numEvalIterations ;
183+ _avgBatchDuration = totalEvalDurations / NumEvalIterations ;
184184 }
185185
186186 private static float EvaluateBatched ( LearningModelSession session , List < VideoFrame > input , int batchSize )
@@ -191,8 +191,11 @@ private static float EvaluateBatched(LearningModelSession session, List<VideoFra
191191 var binding = new LearningModelBinding ( session ) ;
192192 for ( int i = 0 ; i < numBatches ; i ++ )
193193 {
194- if ( navigatingAwayFromPage )
195- break ;
194+ if ( navigatingAwayFromPage )
195+ {
196+ break ;
197+ }
198+
196199 int rangeStart = batchSize * i ;
197200 List < VideoFrame > batch ;
198201 // Add padding to the last batch if necessary
@@ -222,19 +225,19 @@ private int GetBatchSizeFromBatchSizeSlider()
222225 return int . Parse ( BatchSizeSlider . Value . ToString ( ) ) ;
223226 }
224227
225- private void UpdateEvalProgressUI ( int attemptNumber )
228+ private void UpdateProgress ( int attemptNumber )
226229 {
227- EvalProgressText . Text = "Attempt " + attemptNumber . ToString ( ) + "/" + numEvalIterations . ToString ( ) ;
230+ EvalProgressText . Text = "Attempt " + attemptNumber . ToString ( ) + "/" + NumEvalIterations . ToString ( ) ;
228231 EvalProgressBar . Value = attemptNumber + 1 ;
229232 }
230233
231- private void GenerateEvalResultAndUI ( )
234+ private void ShowUI ( )
232235 {
233- float ratio = ( 1 - ( avgBatchDuration_ / avgNonBatchedDuration_ ) ) * 100 ;
236+ float ratio = ( 1 - ( _avgBatchDuration / _avgNonBatchedDuration ) ) * 100 ;
234237 var evalResult = new EvalResult
235238 {
236- nonBatchedAvgTime = avgNonBatchedDuration_ . ToString ( "0.00" ) ,
237- batchedAvgTime = avgBatchDuration_ . ToString ( "0.00" ) ,
239+ nonBatchedAvgTime = _avgNonBatchedDuration . ToString ( "0.00" ) ,
240+ batchedAvgTime = _avgBatchDuration . ToString ( "0.00" ) ,
238241 timeRatio = ratio . ToString ( "0.0" )
239242 } ;
240243 List < EvalResult > results = new List < EvalResult > ( ) ;
0 commit comments