Skip to content

Commit d974a21

Browse files
authored
Add SamplingKeyColumnName to AutoMLExperiment (#6649)
1 parent ce9938d commit d974a21

File tree

4 files changed: 28 additions and 3 deletions.

src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,14 +57,16 @@ public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, Trai
5757
/// </summary>
5858
/// <param name="experiment"><see cref="AutoMLExperiment"/></param>
5959
/// <param name="dataset">dataset for cross-validation split.</param>
60-
/// <param name="fold"></param>
60+
/// <param name="fold">number of cross-validation folds</param>
61+
/// <param name="samplingKeyColumnName">column name for sampling key</param>
6162
/// <returns><see cref="AutoMLExperiment"/></returns>
62-
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView dataset, int fold = 10)
63+
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView dataset, int fold = 10, string samplingKeyColumnName = null)
6364
{
6465
var datasetManager = new CrossValidateDatasetManager()
6566
{
6667
Dataset = dataset,
6768
Fold = fold,
69+
SamplingKeyColumnName = samplingKeyColumnName,
6870
};
6971

7072
experiment.ServiceCollection.AddSingleton<IDatasetManager>(datasetManager);

src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@ internal interface ICrossValidateDatasetManager
1717
int? Fold { get; set; }
1818

1919
IDataView Dataset { get; set; }
20+
21+
string SamplingKeyColumnName { get; set; }
2022
}
2123

2224
internal interface ITrainValidateDatasetManager
@@ -38,5 +40,6 @@ internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateData
3840
public IDataView Dataset { get; set; }
3941

4042
public int? Fold { get; set; }
43+
public string SamplingKeyColumnName { get; set; }
4144
}
4245
}

src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ public TrialResult Run(TrialSettings settings)
4040
var mlnetPipeline = _pipeline.BuildFromOption(_mLContext, parameter);
4141
if (_datasetManager is ICrossValidateDatasetManager crossValidateDatasetManager)
4242
{
43-
var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold ?? 5);
43+
var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold ?? 5, crossValidateDatasetManager.SamplingKeyColumnName);
4444
var metrics = new List<double>();
4545
var models = new List<ITransformer>();
4646
foreach (var split in datasetSplit)

test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -354,6 +354,26 @@ public async Task AutoMLExperiment_Taxi_Fare_CV_5_Test()
354354
result.Metric.Should().BeGreaterThan(0.5);
355355
}
356356

357+
[Fact]
358+
public async Task AutoMLExperiment_Taxi_Fare_CV_5_SamplingKey_Test()
359+
{
360+
var context = new MLContext(1);
361+
var train = DatasetUtil.GetTaxiFareTrainDataView();
362+
var experiment = context.Auto().CreateExperiment();
363+
var label = DatasetUtil.TaxiFareLabel;
364+
var pipeline = context.Auto().Featurizer(train, excludeColumns: new[] { label })
365+
.Append(context.Auto().Regression(label, useLgbm: false, useSdca: false, useLbfgsPoissonRegression: false));
366+
367+
experiment.SetDataset(train, 5, "vendor_id")
368+
.SetRegressionMetric(RegressionMetric.RSquared, label)
369+
.SetPipeline(pipeline)
370+
.SetMaxModelToExplore(1);
371+
372+
var result = await experiment.RunAsync();
373+
result.Metric.Should().BeGreaterThan(0.2);
374+
result.Metric.Should().BeLessThan(0.5);
375+
}
376+
357377
[Fact]
358378
public void AutoMLExperiment_should_use_seed_from_context_if_provided()
359379
{

0 commit comments

Comments
 (0)