Skip to content

Commit b1cb564

Browse files
add SentenceSimilarity sweepable estimator in AutoML (#6445)
1 parent 42788c4 commit b1cb564

File tree

6 files changed

+80
-14
lines changed

6 files changed

+80
-14
lines changed

src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@
7171
"DnnFeaturizerImage",
7272
"Naive",
7373
"ForecastBySsa",
74-
"TextClassification"
74+
"TextClassification",
75+
"SentenceSimilarity"
7576
]
7677
},
7778
"nugetDependencies": {
@@ -109,6 +110,7 @@
109110
"Microsoft.ML.Vision",
110111
"Microsoft.ML.Transforms.Image",
111112
"Microsoft.ML.Trainers.FastTree",
113+
"Microsoft.ML.TorchSharp",
112114
"Microsoft.ML.Trainers.LightGbm"
113115
]
114116
}

src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@
126126
"image_classification_option",
127127
"matrix_factorization_option",
128128
"dnn_featurizer_image_option",
129-
"text_classification_option"
129+
"text_classification_option",
130+
"sentence_similarity_option"
130131
]
131132
},
132133
"option_name": {
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
{
2+
"$schema": "./search-space-schema.json#",
3+
"name": "sentence_similarity_option",
4+
"search_space": [
5+
{
6+
"name": "LabelColumnName",
7+
"type": "string",
8+
"default": "Label"
9+
},
10+
{
11+
"name": "Sentence1ColumnName",
12+
"type": "string",
13+
"default": "Sentence1"
14+
},
15+
{
16+
"name": "Sentence2ColumnName",
17+
"type": "string"
18+
},
19+
{
20+
"name": "ScoreColumnName",
21+
"type": "string",
22+
"default": "Score"
23+
},
24+
{
25+
"name": "BatchSize",
26+
"type": "integer",
27+
"default": 32
28+
},
29+
{
30+
"name": "MaxEpochs",
31+
"type": "integer",
32+
"default": 10
33+
},
34+
{
35+
"name": "Architecture",
36+
"type": "bertArchitecture",
37+
"default": "BertArchitecture.Roberta"
38+
}
39+
]
40+
}

src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@
306306
"argumentType": "boolean"
307307
}
308308
],
309-
"nugetDependencies": ["Microsoft.ML"],
309+
"nugetDependencies": [ "Microsoft.ML" ],
310310
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers" ],
311311
"searchOption": "lbfgs_option"
312312
},
@@ -514,20 +514,17 @@
514514
{
515515
"functionName": "TextClassification",
516516
"estimatorTypes": [ "MultiClassification" ],
517-
"arguments": [
518-
{
519-
"argumentName": "labelColumnName",
520-
"argumentType": "string"
521-
},
522-
{
523-
"argumentName": "sentence1ColumnName",
524-
"argumentType": "string"
525-
}
526-
],
527517
"nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
528518
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
529519
"searchOption": "text_classification_option"
530520
},
521+
{
522+
"functionName": "SentenceSimilarity",
523+
"estimatorTypes": [ "Regression" ],
524+
"nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
525+
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
526+
"searchOption": "sentence_similarity_option"
527+
},
531528
{
532529
"functionName": "ForecastBySsa",
533530
"estimatorTypes": [ "Forecasting" ],

src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
<AdditionalFiles Include="CodeGen\*search_space.json" />
6767
<AdditionalFiles Include="CodeGen\code_gen_flag.json" />
6868
<AdditionalFiles Include="CodeGen\*-estimators.json" />
69-
<AdditionalFiles Include="CodeGen\code_gen_flag.json" />
7069
</ItemGroup>
7170

7271
<Target DependsOnTargets="ResolveReferences" Name="CopyProjectReferencesToPackage">
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Reflection;
8+
using System.Text;
9+
using Microsoft.ML.TorchSharp;
10+
11+
namespace Microsoft.ML.AutoML.CodeGen
12+
{
13+
/// <summary>
/// Partial class that turns a <see cref="SentenceSimilarityOption"/> into the
/// SentenceSimilarity regression trainer. Only this part of the partial class is
/// visible here; the remaining members are declared elsewhere.
/// </summary>
internal partial class SentenceSimilarityRegression
14+
{
15+
/// <summary>
/// Builds the SentenceSimilarity trainer from the regression catalog of <paramref name="context"/>.
/// </summary>
/// <param name="context">The <see cref="MLContext"/> whose <c>Regression.Trainers</c> catalog creates the estimator.</param>
/// <param name="param">Option object supplying the column names, batch size, epoch count and architecture.</param>
/// <returns>The estimator returned by <c>SentenceSimilarity</c>.</returns>
public override IEstimator<ITransformer> BuildFromOption(MLContext context, SentenceSimilarityOption param)
16+
{
17+
// Every argument is passed by name, so the order below does not have to match the trainer's parameter order.
return context.Regression.Trainers.SentenceSimilarity(
18+
labelColumnName: param.LabelColumnName,
19+
sentence1ColumnName: param.Sentence1ColumnName,
20+
scoreColumnName: param.ScoreColumnName,
21+
sentence2ColumnName: param.Sentence2ColumnName,
22+
batchSize: param.BatchSize,
23+
maxEpochs: param.MaxEpochs,
24+
architecture: param.Architecture);
25+
}
26+
}
27+
}

0 commit comments

Comments
 (0)