@@ -114,6 +114,11 @@ public sealed class Options : TransformInputBase
114
114
/// Gets or sets the weight decay in the optimizer.
115
115
/// </summary>
116
116
public double WeightDecay = 0.0 ;
117
+
118
+ /// <summary>
119
+ /// Gets or sets how often the loss is logged, in training updates: the loss is reported whenever the update count is a multiple of this value. Must be non-zero because it is used as a modulo divisor.
120
+ /// </summary>
121
+ public int LogEveryNStep = 50 ;
117
122
}
118
123
119
124
private protected readonly IHost Host ;
@@ -122,7 +127,7 @@ public sealed class Options : TransformInputBase
122
127
123
128
internal ObjectDetectionTrainer ( IHostEnvironment env , Options options )
124
129
{
125
- Host = Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( NasBertTrainer ) ) ;
130
+ Host = Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( ObjectDetectionTrainer ) ) ;
126
131
Contracts . Assert ( options . MaxEpoch > 0 ) ;
127
132
Contracts . AssertValue ( options . BoundingBoxColumnName ) ;
128
133
Contracts . AssertValue ( options . LabelColumnName ) ;
@@ -163,14 +168,21 @@ public ObjectDetectionTransformer Fit(IDataView input)
163
168
using ( var ch = Host . Start ( "TrainModel" ) )
164
169
using ( var pch = Host . StartProgressChannel ( "Training model" ) )
165
170
{
166
- var header = new ProgressHeader ( new [ ] { "Accuracy" } , null ) ;
171
+ var header = new ProgressHeader ( new [ ] { "Loss" } , new [ ] { "total images" } ) ;
172
+
167
173
var trainer = new Trainer ( this , ch , input ) ;
168
- pch . SetHeader ( header , e => e . SetMetric ( 0 , trainer . Accuracy ) ) ;
174
+ pch . SetHeader ( header ,
175
+ e =>
176
+ {
177
+ e . SetProgress ( 0 , trainer . Updates , trainer . RowCount ) ;
178
+ e . SetMetric ( 0 , trainer . LossValue ) ;
179
+ } ) ;
180
+
169
181
for ( int i = 0 ; i < Option . MaxEpoch ; i ++ )
170
182
{
171
183
ch . Trace ( $ "Starting epoch { i } ") ;
172
184
Host . CheckAlive ( ) ;
173
- trainer . Train ( Host , input ) ;
185
+ trainer . Train ( Host , input , pch ) ;
174
186
ch . Trace ( $ "Finished epoch { i } ") ;
175
187
}
176
188
var labelCol = input . Schema . GetColumnOrNull ( Option . LabelColumnName ) ;
@@ -191,17 +203,19 @@ internal class Trainer
191
203
protected readonly ObjectDetectionTrainer Parent ;
192
204
public FocalLoss Loss ;
193
205
public int Updates ;
194
- public float Accuracy ;
206
+ public float LossValue ;
207
+ public readonly int RowCount ;
208
+ private readonly IChannel _channel ;
195
209
196
210
public Trainer ( ObjectDetectionTrainer parent , IChannel ch , IDataView input )
197
211
{
198
212
Parent = parent ;
199
213
Updates = 0 ;
200
- Accuracy = 0 ;
201
-
214
+ LossValue = 0 ;
215
+ _channel = ch ;
202
216
203
217
// Get the row count and determine the number of unique labels
204
- var rowCount = GetRowCountAndSetLabelCount ( input ) ;
218
+ RowCount = GetRowCountAndSetLabelCount ( input ) ;
205
219
Device = TorchUtils . InitializeDevice ( Parent . Host ) ;
206
220
207
221
// Initialize the model and load pre-trained weights
@@ -274,7 +288,7 @@ private string GetModelPath()
274
288
return relativeFilePath ;
275
289
}
276
290
277
- public void Train ( IHost host , IDataView input )
291
+ public void Train ( IHost host , IDataView input , IProgressChannel pch )
278
292
{
279
293
// Get the cursor and the correct columns based on the inputs
280
294
DataViewRowCursor cursor = input . GetRowCursor ( input . Schema [ Parent . Option . LabelColumnName ] , input . Schema [ Parent . Option . BoundingBoxColumnName ] , input . Schema [ Parent . Option . ImageColumnName ] ) ;
@@ -302,7 +316,7 @@ public void Train(IHost host, IDataView input)
302
316
303
317
while ( cursorValid )
304
318
{
305
- cursorValid = TrainStep ( host , cursor , boundingBoxGetter , imageGetter , labelGetter ) ;
319
+ cursorValid = TrainStep ( host , cursor , boundingBoxGetter , imageGetter , labelGetter , pch ) ;
306
320
}
307
321
308
322
LearningRateScheduler . step ( ) ;
@@ -312,7 +326,8 @@ private bool TrainStep(IHost host,
312
326
DataViewRowCursor cursor ,
313
327
ValueGetter < VBuffer < float > > boundingBoxGetter ,
314
328
ValueGetter < MLImage > imageGetter ,
315
- ValueGetter < VBuffer < uint > > labelGetter )
329
+ ValueGetter < VBuffer < uint > > labelGetter ,
330
+ IProgressChannel pch )
316
331
{
317
332
using var disposeScope = torch . NewDisposeScope ( ) ;
318
333
var cursorValid = true ;
@@ -343,6 +358,12 @@ private bool TrainStep(IHost host,
343
358
Optimizer . step ( ) ;
344
359
host . CheckAlive ( ) ;
345
360
361
+ if ( Updates % Parent . Option . LogEveryNStep == 0 )
362
+ {
363
+ pch . Checkpoint ( lossValue . ToDouble ( ) , Updates ) ;
364
+ _channel . Info ( $ "Row: { Updates } , Loss: { lossValue . ToDouble ( ) } ") ;
365
+ }
366
+
346
367
return cursorValid ;
347
368
}
348
369
0 commit comments