1- using Tensorflow . NumPy ;
21using System ;
32using System . Collections . Generic ;
43using System . Linq ;
4+ using Tensorflow ;
55using Tensorflow . Keras . ArgsDefinition ;
6+ using Tensorflow . Keras . Callbacks ;
67using Tensorflow . Keras . Engine . DataAdapters ;
7- using static Tensorflow . Binding ;
88using Tensorflow . Keras . Layers ;
99using Tensorflow . Keras . Utils ;
10- using Tensorflow ;
11- using Tensorflow . Keras . Callbacks ;
10+ using Tensorflow . NumPy ;
11+ using static Tensorflow . Binding ;
1212
1313namespace Tensorflow . Keras . Engine
1414{
@@ -27,7 +27,7 @@ public partial class Model
2727 /// <param name="use_multiprocessing"></param>
2828 /// <param name="return_dict"></param>
2929 /// <param name="is_val"></param>
30- public Dictionary < string , float > evaluate ( NDArray x , NDArray y ,
30+ public Dictionary < string , float > evaluate ( Tensor x , Tensor y ,
3131 int batch_size = - 1 ,
3232 int verbose = 1 ,
3333 int steps = - 1 ,
@@ -64,34 +64,11 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
6464 Verbose = verbose ,
6565 Steps = data_handler . Inferredsteps
6666 } ) ;
67- callbacks . on_test_begin ( ) ;
68-
69- //Dictionary<string, float>? logs = null;
70- var logs = new Dictionary < string , float > ( ) ;
71- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
72- {
73- reset_metrics ( ) ;
74- // data_handler.catch_stop_iteration();
75-
76- foreach ( var step in data_handler . steps ( ) )
77- {
78- callbacks . on_test_batch_begin ( step ) ;
79- logs = test_function ( data_handler , iterator ) ;
80- var end_step = step + data_handler . StepIncrement ;
81- if ( is_val == false )
82- callbacks . on_test_batch_end ( end_step , logs ) ;
83- }
84- }
8567
86- var results = new Dictionary < string , float > ( ) ;
87- foreach ( var log in logs )
88- {
89- results [ log . Key ] = log . Value ;
90- }
91- return results ;
68+ return evaluate ( data_handler , callbacks , is_val , test_function ) ;
9269 }
9370
94- public Dictionary < string , float > evaluate ( IEnumerable < Tensor > x , NDArray y , int verbose = 1 , bool is_val = false )
71+ public Dictionary < string , float > evaluate ( IEnumerable < Tensor > x , Tensor y , int verbose = 1 , bool is_val = false )
9572 {
9673 var data_handler = new DataHandler ( new DataHandlerArgs
9774 {
@@ -107,34 +84,10 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int
10784 Verbose = verbose ,
10885 Steps = data_handler . Inferredsteps
10986 } ) ;
110- callbacks . on_test_begin ( ) ;
11187
112- Dictionary < string , float > logs = null ;
113- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
114- {
115- reset_metrics ( ) ;
116- callbacks . on_epoch_begin ( epoch ) ;
117- // data_handler.catch_stop_iteration();
118-
119- foreach ( var step in data_handler . steps ( ) )
120- {
121- callbacks . on_test_batch_begin ( step ) ;
122- logs = test_step_multi_inputs_function ( data_handler , iterator ) ;
123- var end_step = step + data_handler . StepIncrement ;
124- if ( is_val == false )
125- callbacks . on_test_batch_end ( end_step , logs ) ;
126- }
127- }
128-
129- var results = new Dictionary < string , float > ( ) ;
130- foreach ( var log in logs )
131- {
132- results [ log . Key ] = log . Value ;
133- }
134- return results ;
88+ return evaluate ( data_handler , callbacks , is_val , test_step_multi_inputs_function ) ;
13589 }
13690
137-
13891 public Dictionary < string , float > evaluate ( IDatasetV2 x , int verbose = 1 , bool is_val = false )
13992 {
14093 var data_handler = new DataHandler ( new DataHandlerArgs
@@ -150,9 +103,24 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
150103 Verbose = verbose ,
151104 Steps = data_handler . Inferredsteps
152105 } ) ;
106+
107+ return evaluate ( data_handler , callbacks , is_val , test_function ) ;
108+ }
109+
110+ /// <summary>
111+ /// Internal bare implementation of evaluate function.
112+ /// </summary>
113+ /// <param name="data_handler">Interations handling objects</param>
114+ /// <param name="callbacks"></param>
115+ /// <param name="test_func">The function to be called on each batch of data.</param>
116+ /// <param name="is_val">Whether it is validation or test.</param>
117+ /// <returns></returns>
118+ Dictionary < string , float > evaluate ( DataHandler data_handler , CallbackList callbacks , bool is_val , Func < DataHandler , Tensor [ ] , Dictionary < string , float > > test_func )
119+ {
153120 callbacks . on_test_begin ( ) ;
154121
155- Dictionary < string , float > logs = null ;
122+ var results = new Dictionary < string , float > ( ) ;
123+ var logs = results ;
156124 foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
157125 {
158126 reset_metrics ( ) ;
@@ -162,45 +130,47 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
162130 foreach ( var step in data_handler . steps ( ) )
163131 {
164132 callbacks . on_test_batch_begin ( step ) ;
165- logs = test_function ( data_handler , iterator ) ;
133+
134+ logs = test_func ( data_handler , iterator . next ( ) ) ;
135+
136+ tf_with ( ops . control_dependencies ( Array . Empty < object > ( ) ) , ctl => _train_counter . assign_add ( 1 ) ) ;
137+
166138 var end_step = step + data_handler . StepIncrement ;
167- if ( is_val == false )
139+ if ( ! is_val )
168140 callbacks . on_test_batch_end ( end_step , logs ) ;
169141 }
142+
143+ if ( ! is_val )
144+ callbacks . on_epoch_end ( epoch , logs ) ;
170145 }
171146
172- var results = new Dictionary < string , float > ( ) ;
173147 foreach ( var log in logs )
174148 {
175149 results [ log . Key ] = log . Value ;
176150 }
151+
177152 return results ;
178153 }
179154
180- Dictionary < string , float > test_function ( DataHandler data_handler , OwnedIterator iterator )
155+ Dictionary < string , float > test_function ( DataHandler data_handler , Tensor [ ] data )
181156 {
182- var data = iterator . next ( ) ;
183- var outputs = test_step ( data_handler , data [ 0 ] , data [ 1 ] ) ;
184- tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _test_counter . assign_add ( 1 ) ) ;
157+ var ( x , y ) = data_handler . DataAdapter . Expand1d ( data [ 0 ] , data [ 1 ] ) ;
158+
159+ var y_pred = Apply ( x , training : false ) ;
160+ var loss = compiled_loss . Call ( y , y_pred ) ;
161+
162+ compiled_metrics . update_state ( y , y_pred ) ;
163+
164+ var outputs = metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Name , x => ( float ) x . Item2 ) ;
185165 return outputs ;
186166 }
187- Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , OwnedIterator iterator )
167+
168+ Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , Tensor [ ] data )
188169 {
189- var data = iterator . next ( ) ;
190170 var x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
191171 var outputs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) . ToArray ( ) ) , new Tensors ( data . Skip ( x_size ) . ToArray ( ) ) ) ;
192172 tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _train_counter . assign_add ( 1 ) ) ;
193173 return outputs ;
194174 }
195- Dictionary < string , float > test_step ( DataHandler data_handler , Tensor x , Tensor y )
196- {
197- ( x , y ) = data_handler . DataAdapter . Expand1d ( x , y ) ;
198- var y_pred = Apply ( x , training : false ) ;
199- var loss = compiled_loss . Call ( y , y_pred ) ;
200-
201- compiled_metrics . update_state ( y , y_pred ) ;
202-
203- return metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x=> x . Item1 , x=> ( float ) x . Item2 ) ;
204- }
205175 }
206176}
0 commit comments