Skip to content

Commit 4c5aa85

Browse files
fix refit (#6572)
1 parent fe2a442 commit 4c5aa85

File tree

4 files changed

+41
-12
lines changed

4 files changed

+41
-12
lines changed

src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,14 @@ internal class BinaryClassificationRunner : ITrialRunner
341341
{
342342
private MLContext _context;
343343
private readonly IDatasetManager _datasetManager;
344+
private readonly IMLContextManager _contextManager;
344345
private readonly IMetricManager _metricManager;
345346
private readonly SweepablePipeline _pipeline;
346347
private readonly Random _rnd;
347-
public BinaryClassificationRunner(MLContext context, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
348+
public BinaryClassificationRunner(IMLContextManager contextManager, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
348349
{
349-
_context = context;
350+
_context = contextManager.CreateMLContext();
351+
_contextManager = contextManager;
350352
_datasetManager = datasetManager;
351353
_metricManager = metricManager;
352354
_pipeline = pipeline;
@@ -365,6 +367,10 @@ public TrialResult Run(TrialSettings settings)
365367
{
366368
var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName];
367369
var pipeline = _pipeline.BuildFromOption(_context, parameter);
370+
// _context will be cancelled after training. So returned pipeline need to be created on a
371+
// new MLContext.
372+
var refitContext = _contextManager.CreateMLContext();
373+
var refitPipeline = _pipeline.BuildFromOption(refitContext, parameter);
368374
if (_datasetManager is ICrossValidateDatasetManager datasetManager)
369375
{
370376
var stopWatch = new Stopwatch();
@@ -396,7 +402,7 @@ public TrialResult Run(TrialSettings settings)
396402
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
397403
Metrics = res.Metrics,
398404
CrossValidationMetrics = metrics,
399-
Pipeline = pipeline,
405+
Pipeline = refitPipeline,
400406
};
401407
}
402408

@@ -430,7 +436,7 @@ public TrialResult Run(TrialSettings settings)
430436
TrialSettings = settings,
431437
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
432438
Metrics = metrics,
433-
Pipeline = pipeline,
439+
Pipeline = refitPipeline,
434440
};
435441
}
436442
}

src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,12 +343,14 @@ internal class MulticlassClassificationRunner : ITrialRunner
343343
private MLContext _context;
344344
private readonly IDatasetManager _datasetManager;
345345
private readonly IMetricManager _metricManager;
346+
private readonly IMLContextManager _contextManager;
346347
private readonly SweepablePipeline _pipeline;
347348
private readonly Random _rnd;
348349

349-
public MulticlassClassificationRunner(MLContext context, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
350+
public MulticlassClassificationRunner(IMLContextManager contextManager, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
350351
{
351-
_context = context;
352+
_context = contextManager.CreateMLContext();
353+
_contextManager = contextManager;
352354
_datasetManager = datasetManager;
353355
_metricManager = metricManager;
354356
_pipeline = pipeline;
@@ -361,6 +363,8 @@ public TrialResult Run(TrialSettings settings)
361363
{
362364
var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName];
363365
var pipeline = _pipeline.BuildFromOption(_context, parameter);
366+
var refitContext = _contextManager.CreateMLContext();
367+
var refitPipeline = _pipeline.BuildFromOption(refitContext, parameter);
364368
if (_datasetManager is ICrossValidateDatasetManager datasetManager)
365369
{
366370
var stopWatch = new Stopwatch();
@@ -394,7 +398,7 @@ public TrialResult Run(TrialSettings settings)
394398
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
395399
Metrics = res.Metrics,
396400
CrossValidationMetrics = metrics,
397-
Pipeline = pipeline,
401+
Pipeline = refitPipeline,
398402
};
399403
}
400404

@@ -428,7 +432,7 @@ public TrialResult Run(TrialSettings settings)
428432
TrialSettings = settings,
429433
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
430434
Metrics = metrics,
431-
Pipeline = pipeline,
435+
Pipeline = refitPipeline,
432436
};
433437
}
434438
}

src/Microsoft.ML.AutoML/API/RegressionExperiment.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,12 +363,14 @@ internal class RegressionTrialRunner : ITrialRunner
363363
private MLContext _context;
364364
private readonly IDatasetManager _datasetManager;
365365
private readonly IMetricManager _metricManager;
366+
private readonly IMLContextManager _contextManager;
366367
private readonly SweepablePipeline _pipeline;
367368
private readonly Random _rnd;
368369

369-
public RegressionTrialRunner(MLContext context, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
370+
public RegressionTrialRunner(IMLContextManager contextManager, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
370371
{
371-
_context = context;
372+
_context = contextManager.CreateMLContext();
373+
_contextManager = contextManager;
372374
_datasetManager = datasetManager;
373375
_metricManager = metricManager;
374376
_pipeline = pipeline;
@@ -388,6 +390,8 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
388390
{
389391
var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName];
390392
var pipeline = _pipeline.BuildFromOption(_context, parameter);
393+
var refitContext = _contextManager.CreateMLContext();
394+
var refitPipeline = _pipeline.BuildFromOption(refitContext, parameter);
391395
if (_datasetManager is ICrossValidateDatasetManager datasetManager)
392396
{
393397
var stopWatch = new Stopwatch();
@@ -420,7 +424,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
420424
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
421425
Metrics = res.Metrics,
422426
CrossValidationMetrics = metrics,
423-
Pipeline = pipeline,
427+
Pipeline = refitPipeline,
424428
} as TrialResult);
425429
}
426430

@@ -453,7 +457,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
453457
TrialSettings = settings,
454458
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
455459
Metrics = res,
456-
Pipeline = pipeline,
460+
Pipeline = refitPipeline,
457461
} as TrialResult);
458462
}
459463
}

test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ public void AutoFit_UCI_Adult_CrossValidation_10_Test()
113113
Assert.True(result.BestRun.Results.Select(x => x.ValidationMetrics.Accuracy).Min() > 0.70);
114114
Assert.NotNull(result.BestRun.Estimator);
115115
Assert.NotNull(result.BestRun.TrainerName);
116+
117+
// test refit
118+
var model = result.BestRun.Estimator.Fit(trainData);
119+
Assert.NotNull(model);
116120
}
117121

118122
[Fact]
@@ -214,6 +218,10 @@ public void AutoFit_Taxi_Fare_Test()
214218
Assert.True(result.BestRun.ValidationMetrics.RSquared > 0.70);
215219
Assert.NotNull(result.BestRun.Estimator);
216220
Assert.NotNull(result.BestRun.TrainerName);
221+
222+
// verify refit
223+
var model = result.BestRun.Estimator.Fit(context.Data.TakeRows(dataset, 1000));
224+
Assert.NotNull(model);
217225
}
218226

219227
[Theory]
@@ -253,6 +261,10 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds)
253261
result.BestRun.Results.First().ValidationMetrics.MicroAccuracy.Should().BeGreaterThan(0.7);
254262
var scoredData = result.BestRun.Results.First().Model.Transform(trainData);
255263
Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type);
264+
265+
// test refit
266+
var model = result.BestRun.Estimator.Fit(trainData);
267+
Assert.NotNull(model);
256268
}
257269
else
258270
{
@@ -281,6 +293,9 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds)
281293
Assert.True(result.BestRun.ValidationMetrics.MicroAccuracy >= 0.7);
282294
var scoredData = result.BestRun.Model.Transform(trainData);
283295
Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type);
296+
297+
var model = result.BestRun.Estimator.Fit(trainData);
298+
Assert.NotNull(model);
284299
}
285300
}
286301

0 commit comments

Comments
 (0)