Skip to content

Commit db19715

Browse files
Add default search space for standard trainers (#6576)
* add default search space * add default search space for standard trainers * fix system.text.json build error * fix tests
1 parent 4c5aa85 commit db19715

File tree

41 files changed

+1088
-144
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1088
-144
lines changed

docs/samples/Microsoft.ML.Samples.GPU/Microsoft.ML.Samples.GPU.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
<ItemGroup>
4646
<PackageReference Include="Microsoft.ML.Onnx.TestModels" Version="$(MicrosoftMLOnnxTestModelsVersion)" />
47+
<PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
4748
</ItemGroup>
4849
<ItemGroup Condition=" '$(OS)' == 'Windows_NT'">
4950
<PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="$(TensorFlowVersion)" />

docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,7 @@
977977
<PackageReference Include="Microsoft.ML.Onnx.TestModels" Version="$(MicrosoftMLOnnxTestModelsVersion)" />
978978
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="$(TensorFlowVersion)" />
979979
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimeVersion)" />
980+
<PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
980981
</ItemGroup>
981982

982983
<ItemGroup>

src/Microsoft.ML.Console/Microsoft.ML.Console.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
4-
<TargetFramework>netcoreapp3.1</TargetFramework>
4+
<TargetFramework>net6.0</TargetFramework>
55
<OutputType>Exe</OutputType>
66
<AssemblyName>MML</AssemblyName>
77
<StartupObject>Microsoft.ML.Tools.Console.Console</StartupObject>
@@ -27,6 +27,7 @@
2727
<ProjectReference Include="..\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
2828
<ProjectReference Include="..\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj" />
2929
<ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
30+
<PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
3031

3132
<NativeAssemblyReference Include="FastTreeNative" />
3233
<NativeAssemblyReference Include="CpuMathNative" />
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
using System.Text.Json;
9+
using System.Text.Json.Serialization;
10+
using Microsoft.ML.SearchSpace.Option;
11+
12+
namespace Microsoft.ML.SearchSpace.Converter
13+
{
14+
internal class ChoiceOptionConverter : JsonConverter<ChoiceOption>
15+
{
16+
class Schema
17+
{
18+
/// <summary>
19+
/// must be one of "int" | "float" | "double"
20+
/// </summary>
21+
[JsonPropertyName("default")]
22+
public object Default { get; set; }
23+
24+
[JsonPropertyName("choices")]
25+
public object[] Choices { get; set; }
26+
}
27+
28+
public override ChoiceOption Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
29+
{
30+
var schema = JsonSerializer.Deserialize<Schema>(ref reader, options);
31+
32+
return new ChoiceOption(schema.Choices, schema.Default);
33+
}
34+
35+
public override void Write(Utf8JsonWriter writer, ChoiceOption value, JsonSerializerOptions options)
36+
{
37+
var schema = new Schema
38+
{
39+
Choices = value.Choices,
40+
Default = value.SampleFromFeatureSpace(value.Default),
41+
};
42+
43+
JsonSerializer.Serialize(writer, schema, options);
44+
}
45+
}
46+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
using System.Text.Json;
9+
using System.Text.Json.Serialization;
10+
using Microsoft.ML.SearchSpace.Option;
11+
12+
namespace Microsoft.ML.SearchSpace.Converter
13+
{
14+
internal class NumericOptionConverter : JsonConverter<UniformNumericOption>
15+
{
16+
class Schema
17+
{
18+
/// <summary>
19+
/// must be one of "int" | "float" | "double"
20+
/// </summary>
21+
[JsonPropertyName("type")]
22+
public string Type { get; set; }
23+
24+
[JsonPropertyName("default")]
25+
public object Default { get; set; }
26+
27+
[JsonPropertyName("min")]
28+
public object Min { get; set; }
29+
30+
[JsonPropertyName("max")]
31+
public object Max { get; set; }
32+
33+
[JsonPropertyName("log_base")]
34+
public bool LogBase { get; set; }
35+
}
36+
37+
public override UniformNumericOption Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
38+
{
39+
var schema = JsonSerializer.Deserialize<Schema>(ref reader, options);
40+
41+
return schema.Type switch
42+
{
43+
"int" => new UniformIntOption(Convert.ToInt32(schema.Min), Convert.ToInt32(schema.Max), schema.LogBase, Convert.ToInt32(schema.Default)),
44+
"float" => new UniformSingleOption(Convert.ToSingle(schema.Min), Convert.ToSingle(schema.Max), schema.LogBase, Convert.ToSingle(schema.Default)),
45+
"double" => new UniformDoubleOption(Convert.ToDouble(schema.Min), Convert.ToDouble(schema.Max), schema.LogBase, Convert.ToDouble(schema.Default)),
46+
_ => throw new ArgumentException($"unknown schema type: {schema.Type}"),
47+
};
48+
}
49+
50+
public override void Write(Utf8JsonWriter writer, UniformNumericOption value, JsonSerializerOptions options)
51+
{
52+
var schema = value switch
53+
{
54+
UniformIntOption intOption => new Schema
55+
{
56+
Type = "int",
57+
Default = intOption.SampleFromFeatureSpace(intOption.Default).AsType<int>(),
58+
Min = Convert.ToInt32(intOption.Min),
59+
Max = Convert.ToInt32(intOption.Max),
60+
LogBase = intOption.LogBase,
61+
},
62+
UniformDoubleOption doubleOption => new Schema
63+
{
64+
Type = "double",
65+
Default = doubleOption.SampleFromFeatureSpace(doubleOption.Default).AsType<double>(),
66+
Min = doubleOption.Min,
67+
Max = doubleOption.Max,
68+
LogBase = doubleOption.LogBase,
69+
},
70+
UniformSingleOption singleOption => new Schema
71+
{
72+
Type = "float",
73+
Default = singleOption.SampleFromFeatureSpace(singleOption.Default).AsType<Single>(),
74+
Min = Convert.ToSingle(singleOption.Min),
75+
Max = Convert.ToSingle(singleOption.Max),
76+
LogBase = singleOption.LogBase,
77+
},
78+
_ => throw new ArgumentException("unknown type"),
79+
};
80+
81+
JsonSerializer.Serialize(writer, schema, options);
82+
}
83+
}
84+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
using System.Text.Json;
9+
using System.Text.Json.Serialization;
10+
using Microsoft.ML.SearchSpace.Option;
11+
12+
namespace Microsoft.ML.SearchSpace.Converter
13+
{
14+
internal class OptionConverter : JsonConverter<OptionBase>
15+
{
16+
public override OptionBase Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
17+
{
18+
try
19+
{
20+
return JsonSerializer.Deserialize<SearchSpace>(ref reader, options);
21+
}
22+
catch (Exception)
23+
{
24+
// try choice option
25+
}
26+
27+
try
28+
{
29+
return JsonSerializer.Deserialize<ChoiceOption>(ref reader, options);
30+
}
31+
catch (Exception)
32+
{
33+
// try numeric option
34+
}
35+
36+
try
37+
{
38+
return JsonSerializer.Deserialize<UniformNumericOption>(ref reader, options);
39+
}
40+
catch (Exception)
41+
{
42+
throw new ArgumentException("unknown option type");
43+
}
44+
}
45+
46+
public override void Write(Utf8JsonWriter writer, OptionBase value, JsonSerializerOptions options)
47+
{
48+
if (value is SearchSpace ss)
49+
{
50+
JsonSerializer.Serialize(writer, ss, options);
51+
}
52+
else if (value is ChoiceOption choiceOption)
53+
{
54+
JsonSerializer.Serialize(writer, choiceOption, options);
55+
}
56+
else if (value is UniformNumericOption uniformNumericOption)
57+
{
58+
JsonSerializer.Serialize(writer, uniformNumericOption, options);
59+
}
60+
else
61+
{
62+
throw new ArgumentException("unknown option type");
63+
}
64+
}
65+
}
66+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
using System.Text.Json;
9+
using System.Text.Json.Serialization;
10+
using Microsoft.ML.SearchSpace.Option;
11+
12+
namespace Microsoft.ML.SearchSpace.Converter
13+
{
14+
internal class SearchSpaceConverter : JsonConverter<SearchSpace>
15+
{
16+
public override SearchSpace Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
17+
{
18+
var optionKVPairs = JsonSerializer.Deserialize<Dictionary<string, OptionBase>>(ref reader, options);
19+
20+
return new SearchSpace(optionKVPairs);
21+
}
22+
23+
public override void Write(Utf8JsonWriter writer, SearchSpace value, JsonSerializerOptions options)
24+
{
25+
JsonSerializer.Serialize<IDictionary<string, OptionBase>>(value, options);
26+
}
27+
}
28+
}

src/Microsoft.ML.SearchSpace/Microsoft.ML.SearchSpace.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
<PropertyGroup>
44
<TargetFramework>netstandard2.0</TargetFramework>
5-
<IncludeInPackage>Microsoft.ML.AutoML</IncludeInPackage>
5+
<IncludeInPackage>Microsoft.ML.Core</IncludeInPackage>
66
<DisableImplicitNamespaceImports>true</DisableImplicitNamespaceImports>
77
<NoWarn>MSML_ContractsCheckMessageNotLiteralOrIdentifier</NoWarn>
88
<LangVersion>9.0</LangVersion>

src/Microsoft.ML.SearchSpace/Option/ChoiceOption.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using System;
66
using System.Diagnostics.Contracts;
77
using System.Linq;
8+
using System.Text.Json.Serialization;
9+
using Microsoft.ML.SearchSpace.Converter;
810

911
#nullable enable
1012

@@ -13,6 +15,7 @@ namespace Microsoft.ML.SearchSpace.Option
1315
/// <summary>
1416
/// This class represent option for discrete value, such as string, enum, etc..
1517
/// </summary>
18+
[JsonConverter(typeof(ChoiceOptionConverter))]
1619
public sealed class ChoiceOption : OptionBase
1720
{
1821
private readonly UniformSingleOption _option;

src/Microsoft.ML.SearchSpace/Option/NestOption.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
namespace Microsoft.ML.SearchSpace.Option
1212
{
1313
/// <summary>
14-
/// This class represent nest option, which is an option that contains other options, like <see cref="ChoiceOption"/>, <see cref="UniformNumericOption"/> or even <see cref="NestOption"/> itself.
14+
/// This class represent nest option, which is an option that contains other options, like <see cref="ChoiceOption"/>, <see cref="UniformNumericOption"/> or even <see cref="SearchSpace"/> itself.
1515
/// </summary>
16-
public sealed class NestOption : OptionBase, IDictionary<string, OptionBase>
16+
public sealed class SearchSpace : OptionBase, IDictionary<string, OptionBase>
1717
{
1818
private readonly Dictionary<string, OptionBase> _options = new Dictionary<string, OptionBase>();
1919

0 commit comments

Comments
 (0)