Skip to content

Commit fe2a442

Browse files
fix cancellation bug in SweepablePipelineRunner && fix object null exception in AutoML v1.0 regression API (#6560)
* update * update * blocking run task * revert tests * add more logging * fix #6558 * update * fix tests * fix build * fix tests * fix build error * add logger * increase time * fix tests
1 parent d239fda commit fe2a442

File tree

7 files changed

+44
-22
lines changed

7 files changed

+44
-22
lines changed

src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
447447
_context?.CancelExecution();
448448
}))
449449
{
450-
return Task.Run(() => Run(settings));
450+
return Task.FromResult(Run(settings));
451451
}
452452
}
453453
catch (Exception ex) when (ct.IsCancellationRequested)

src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
445445
_context?.CancelExecution();
446446
}))
447447
{
448-
return Task.Run(() => Run(settings));
448+
return Task.FromResult(Run(settings));
449449
}
450450
}
451451
catch (Exception ex) when (ct.IsCancellationRequested)

src/Microsoft.ML.AutoML/API/RegressionExperiment.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ public override ExperimentResult<RegressionMetrics> Execute(IDataView trainData,
159159
int numCrossValFolds = 10;
160160
_experiment.SetDataset(trainData, numCrossValFolds);
161161
_pipeline = CreateRegressionPipeline(trainData, columnInformation, preFeaturizer);
162-
162+
_experiment.SetPipeline(_pipeline);
163163
TrialResultMonitor<RegressionMetrics> monitor = null;
164164
_experiment.SetMonitor((provider) =>
165165
{

src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,7 @@ private void InitializeServiceCollection()
5151
_serviceCollection.TryAddTransient((provider) =>
5252
{
5353
var contextManager = provider.GetRequiredService<IMLContextManager>();
54-
var trainingStopManager = provider.GetRequiredService<AggregateTrainingStopManager>();
5554
var context = contextManager.CreateMLContext();
56-
trainingStopManager.OnStopTraining += (s, e) =>
57-
{
58-
// only force-canceling running trials when there's completed trials.
59-
// otherwise, wait for the current running trial to be completed.
60-
if (_bestTrialResult != null)
61-
context.CancelExecution();
62-
};
6355

6456
return context;
6557
});

src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
9696
_mLContext?.CancelExecution();
9797
}))
9898
{
99-
return Task.Run(() => Run(settings));
99+
return Task.FromResult(Run(settings));
100100
}
101101
}
102102
catch (Exception ex) when (ct.IsCancellationRequested)

test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public void AutoFit_UCI_Adult_Test()
4848
var trainData = textLoader.Load(dataPath);
4949
var settings = new BinaryExperimentSettings
5050
{
51-
MaxExperimentTimeInSeconds = 1,
51+
MaxModels = 1,
5252
};
5353

5454
settings.Trainers.Remove(BinaryClassificationTrainer.LightGbm);
@@ -75,7 +75,7 @@ public void AutoFit_UCI_Adult_Train_Test_Split_Test()
7575
var dataTrainTest = context.Data.TrainTestSplit(trainData);
7676
var settings = new BinaryExperimentSettings
7777
{
78-
MaxExperimentTimeInSeconds = 1,
78+
MaxModels = 1,
7979
};
8080

8181
settings.Trainers.Remove(BinaryClassificationTrainer.LightGbm);
@@ -101,7 +101,7 @@ public void AutoFit_UCI_Adult_CrossValidation_10_Test()
101101
var trainData = textLoader.Load(dataPath);
102102
var settings = new BinaryExperimentSettings
103103
{
104-
MaxExperimentTimeInSeconds = 1,
104+
MaxModels = 1,
105105
};
106106

107107
settings.Trainers.Remove(BinaryClassificationTrainer.LightGbm);
@@ -197,13 +197,23 @@ public void AutoFit_Taxi_Fare_Test()
197197
settings.Trainers.Remove(RegressionTrainer.StochasticDualCoordinateAscent);
198198
settings.Trainers.Remove(RegressionTrainer.LbfgsPoissonRegression);
199199

200+
// verify for dataset > 15000L
200201
var result = context.Auto()
201202
.CreateRegressionExperiment(settings)
202203
.Execute(dataset, label);
203204

204205
Assert.True(result.BestRun.ValidationMetrics.RSquared > 0.70);
205206
Assert.NotNull(result.BestRun.Estimator);
206207
Assert.NotNull(result.BestRun.TrainerName);
208+
209+
// verify for dataset < 15000L
210+
result = context.Auto()
211+
.CreateRegressionExperiment(settings)
212+
.Execute(context.Data.TakeRows(dataset, 1000), label);
213+
214+
Assert.True(result.BestRun.ValidationMetrics.RSquared > 0.70);
215+
Assert.NotNull(result.BestRun.Estimator);
216+
Assert.NotNull(result.BestRun.TrainerName);
207217
}
208218

209219
[Theory]
@@ -229,7 +239,7 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds)
229239
uint numberOfCVFolds = 5;
230240
var settings = new MulticlassExperimentSettings
231241
{
232-
MaxExperimentTimeInSeconds = 1,
242+
MaxModels = 1,
233243
};
234244

235245
settings.Trainers.Remove(MulticlassClassificationTrainer.LightGbm);
@@ -257,7 +267,7 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds)
257267
trainData = context.Data.TakeRows(trainData, crossValRowCountThreshold - 1);
258268
var settings = new MulticlassExperimentSettings
259269
{
260-
MaxExperimentTimeInSeconds = 1,
270+
MaxModels = 1,
261271
};
262272

263273
settings.Trainers.Remove(MulticlassClassificationTrainer.LightGbm);
@@ -286,8 +296,13 @@ public void AutoFitMultiClassification_Image_TrainTest()
286296
TrainTestData trainTestData = context.Data.TrainTestSplit(trainData, testFraction: 0.2, seed: 1);
287297
IDataView trainDataset = SplitUtil.DropAllColumnsExcept(context, trainTestData.TrainSet, originalColumnNames);
288298
IDataView testDataset = SplitUtil.DropAllColumnsExcept(context, trainTestData.TestSet, originalColumnNames);
299+
var settings = new MulticlassExperimentSettings
300+
{
301+
MaxModels = 1,
302+
};
303+
289304
var result = context.Auto()
290-
.CreateMulticlassClassificationExperiment(20)
305+
.CreateMulticlassClassificationExperiment(settings)
291306
.Execute(trainDataset, testDataset, columnInference.ColumnInformation);
292307

293308
result.BestRun.ValidationMetrics.MicroAccuracy.Should().BeGreaterThan(0.1);
@@ -305,8 +320,12 @@ public void AutoFitMultiClassification_Image_CV()
305320
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions);
306321
var trainData = context.Data.ShuffleRows(textLoader.Load(datasetPath), seed: 1);
307322
var originalColumnNames = trainData.Schema.Select(c => c.Name);
323+
var settings = new MulticlassExperimentSettings
324+
{
325+
MaxModels = 1,
326+
};
308327
var result = context.Auto()
309-
.CreateMulticlassClassificationExperiment(100)
328+
.CreateMulticlassClassificationExperiment(settings)
310329
.Execute(trainData, 5, columnInference.ColumnInformation);
311330

312331
result.BestRun.Results.Select(x => x.ValidationMetrics.MicroAccuracy).Max().Should().BeGreaterThan(0.1);
@@ -330,8 +349,12 @@ public void AutoFitMultiClassification_Image()
330349
var columnInference = context.Auto().InferColumns(datasetPath, "Label");
331350
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions);
332351
var trainData = textLoader.Load(datasetPath);
352+
var settings = new MulticlassExperimentSettings
353+
{
354+
MaxModels = 1,
355+
};
333356
var result = context.Auto()
334-
.CreateMulticlassClassificationExperiment(100)
357+
.CreateMulticlassClassificationExperiment(settings)
335358
.Execute(trainData, columnInference.ColumnInformation);
336359

337360
Assert.InRange(result.BestRun.ValidationMetrics.MicroAccuracy, 0.1, 0.9);
@@ -358,7 +381,7 @@ public void AutoFitRankingTest()
358381
// STEP 2: Run AutoML experiment
359382
var settings = new RankingExperimentSettings()
360383
{
361-
MaxExperimentTimeInSeconds = 5,
384+
MaxModels = 5,
362385
OptimizationMetricTruncationLevel = 3
363386
};
364387
var experiment = mlContext.Auto()

test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,21 @@ public async Task AutoMLExperiment_return_current_best_trial_when_ct_is_canceled
151151
public async Task AutoMLExperiment_finish_training_when_time_is_up_Async()
152152
{
153153
var context = new MLContext(1);
154+
context.Log += (o, e) =>
155+
{
156+
if (e.Source.StartsWith("AutoMLExperiment"))
157+
{
158+
this.Output.WriteLine(e.RawMessage);
159+
}
160+
};
154161

155162
var experiment = context.Auto().CreateExperiment();
156163
experiment.SetTrainingTimeInSeconds(5)
157164
.SetTrialRunner((serviceProvider) =>
158165
{
159166
var channel = serviceProvider.GetService<IChannel>();
160167
var settings = serviceProvider.GetService<AutoMLExperiment.AutoMLExperimentSettings>();
161-
return new DummyTrialRunner(settings, 1, channel);
168+
return new DummyTrialRunner(settings, 0, channel);
162169
})
163170
.SetTuner<RandomSearchTuner>();
164171

0 commit comments

Comments
 (0)