Skip to content

Commit 1a905be

Browse files
Added more logging to OBJ-DET (#6646)
* Added more logging to OBJ-DET * Make obj-det images smaller for test speed
1 parent 07b6f45 commit 1a905be

File tree

12 files changed

+49
-14
lines changed

12 files changed

+49
-14
lines changed

src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ public sealed class Options : TransformInputBase
114114
/// Gets or sets the weight decay in optimizer.
115115
/// </summary>
116116
public double WeightDecay = 0.0;
117+
118+
/// <summary>
119+
/// How often to log the loss: an entry is logged whenever the trainer's update counter is a multiple of this value.
120+
/// </summary>
121+
public int LogEveryNStep = 50;
117122
}
118123

119124
private protected readonly IHost Host;
@@ -122,7 +127,7 @@ public sealed class Options : TransformInputBase
122127

123128
internal ObjectDetectionTrainer(IHostEnvironment env, Options options)
124129
{
125-
Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(NasBertTrainer));
130+
Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ObjectDetectionTrainer));
126131
Contracts.Assert(options.MaxEpoch > 0);
127132
Contracts.AssertValue(options.BoundingBoxColumnName);
128133
Contracts.AssertValue(options.LabelColumnName);
@@ -163,14 +168,21 @@ public ObjectDetectionTransformer Fit(IDataView input)
163168
using (var ch = Host.Start("TrainModel"))
164169
using (var pch = Host.StartProgressChannel("Training model"))
165170
{
166-
var header = new ProgressHeader(new[] { "Accuracy" }, null);
171+
var header = new ProgressHeader(new[] { "Loss" }, new[] { "total images" });
172+
167173
var trainer = new Trainer(this, ch, input);
168-
pch.SetHeader(header, e => e.SetMetric(0, trainer.Accuracy));
174+
pch.SetHeader(header,
175+
e =>
176+
{
177+
e.SetProgress(0, trainer.Updates, trainer.RowCount);
178+
e.SetMetric(0, trainer.LossValue);
179+
});
180+
169181
for (int i = 0; i < Option.MaxEpoch; i++)
170182
{
171183
ch.Trace($"Starting epoch {i}");
172184
Host.CheckAlive();
173-
trainer.Train(Host, input);
185+
trainer.Train(Host, input, pch);
174186
ch.Trace($"Finished epoch {i}");
175187
}
176188
var labelCol = input.Schema.GetColumnOrNull(Option.LabelColumnName);
@@ -191,17 +203,19 @@ internal class Trainer
191203
protected readonly ObjectDetectionTrainer Parent;
192204
public FocalLoss Loss;
193205
public int Updates;
194-
public float Accuracy;
206+
public float LossValue;
207+
public readonly int RowCount;
208+
private readonly IChannel _channel;
195209

196210
public Trainer(ObjectDetectionTrainer parent, IChannel ch, IDataView input)
197211
{
198212
Parent = parent;
199213
Updates = 0;
200-
Accuracy = 0;
201-
214+
LossValue = 0;
215+
_channel = ch;
202216

203217
// Get row count and figure out num of unique labels
204-
var rowCount = GetRowCountAndSetLabelCount(input);
218+
RowCount = GetRowCountAndSetLabelCount(input);
205219
Device = TorchUtils.InitializeDevice(Parent.Host);
206220

207221
// Initialize the model and load pre-trained weights
@@ -274,7 +288,7 @@ private string GetModelPath()
274288
return relativeFilePath;
275289
}
276290

277-
public void Train(IHost host, IDataView input)
291+
public void Train(IHost host, IDataView input, IProgressChannel pch)
278292
{
279293
// Get the cursor and the correct columns based on the inputs
280294
DataViewRowCursor cursor = input.GetRowCursor(input.Schema[Parent.Option.LabelColumnName], input.Schema[Parent.Option.BoundingBoxColumnName], input.Schema[Parent.Option.ImageColumnName]);
@@ -302,7 +316,7 @@ public void Train(IHost host, IDataView input)
302316

303317
while (cursorValid)
304318
{
305-
cursorValid = TrainStep(host, cursor, boundingBoxGetter, imageGetter, labelGetter);
319+
cursorValid = TrainStep(host, cursor, boundingBoxGetter, imageGetter, labelGetter, pch);
306320
}
307321

308322
LearningRateScheduler.step();
@@ -312,7 +326,8 @@ private bool TrainStep(IHost host,
312326
DataViewRowCursor cursor,
313327
ValueGetter<VBuffer<float>> boundingBoxGetter,
314328
ValueGetter<MLImage> imageGetter,
315-
ValueGetter<VBuffer<uint>> labelGetter)
329+
ValueGetter<VBuffer<uint>> labelGetter,
330+
IProgressChannel pch)
316331
{
317332
using var disposeScope = torch.NewDisposeScope();
318333
var cursorValid = true;
@@ -343,6 +358,12 @@ private bool TrainStep(IHost host,
343358
Optimizer.step();
344359
host.CheckAlive();
345360

361+
if (Updates % Parent.Option.LogEveryNStep == 0)
362+
{
363+
pch.Checkpoint(lossValue.ToDouble(), Updates);
364+
_channel.Info($"Row: {Updates}, Loss: {lossValue.ToDouble()}");
365+
}
366+
346367
return cursorValid;
347368
}
348369

test/Microsoft.ML.Tests/ObjectDetectionTests.cs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
using System.Linq;
77
using Microsoft.ML.Data;
88
using Microsoft.ML.RunTests;
9-
using Microsoft.ML.Transforms.Image;
109
using Microsoft.VisualBasic;
1110
using Microsoft.ML.TorchSharp;
1211
using Xunit;
1312
using Xunit.Abstractions;
1413
using Microsoft.ML.TorchSharp.AutoFormerV2;
14+
using Microsoft.ML.Runtime;
15+
using System.Collections.Generic;
1516

1617
namespace Microsoft.ML.Tests
1718
{
@@ -50,13 +51,13 @@ public void SimpleObjDetectionTest()
5051
.Append(ML.MulticlassClassification.Trainers.ObjectDetection("Labels", boundingBoxColumnName: "Box", maxEpoch: 1))
5152
.Append(ML.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
5253

53-
5454
var options = new ObjectDetectionTrainer.Options()
5555
{
5656
LabelColumnName = "Labels",
5757
BoundingBoxColumnName = "Box",
5858
ScoreThreshold = .5,
59-
MaxEpoch = 1
59+
MaxEpoch = 1,
60+
LogEveryNStep = 1,
6061
};
6162

6263
var pipeline = ML.Transforms.Text.TokenizeIntoWords("Labels", separators: new char[] { ',' })
@@ -67,13 +68,26 @@ public void SimpleObjDetectionTest()
6768
.Append(ML.MulticlassClassification.Trainers.ObjectDetection(options))
6869
.Append(ML.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
6970

71+
var logs = new List<LoggingEventArgs>();
72+
73+
ML.Log += (o, e) =>
74+
{
75+
if (e.Source.StartsWith("ObjectDetectionTrainer") && e.Kind == ChannelMessageKind.Info && e.Message.Contains("Loss:"))
76+
{
77+
logs.Add(e);
78+
}
79+
};
80+
7081
var model = pipeline.Fit(data);
7182
var idv = model.Transform(data);
7283
// Make sure the metrics work.
7384
var metrics = ML.MulticlassClassification.EvaluateObjectDetection(idv, idv.Schema[2], idv.Schema["Box"], idv.Schema["PredictedLabel"], idv.Schema["PredictedBoundingBoxes"], idv.Schema["Score"]);
7485
Assert.True(!float.IsNaN(metrics.MAP50));
7586
Assert.True(!float.IsNaN(metrics.MAP50_95));
7687

88+
// We aren't doing enough training to get a consistent loss, so just make sure it's being logged
89+
Assert.True(logs.Count > 0);
90+
7791
// Make sure the filtered pipeline can run without any columns but image column AFTER training
7892
var dataFiltered = TextLoader.Create(ML, new TextLoader.Options()
7993
{
-262 KB
Loading
-263 KB
Loading
-239 KB
Loading
-185 KB
Loading
-184 KB
Loading
-185 KB
Loading
-178 KB
Loading
-181 KB
Loading

0 commit comments

Comments
 (0)