Skip to content

Commit b2fa350

Browse files
authored
feat: add base classes for ML and refine code base (#1031)
1 parent c89bd28 commit b2fa350

File tree

44 files changed

+1793
-212
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

44 files changed

+1793
-212
lines changed

src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureBaseTests.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ protected FeatureBaseTests(SparkFixture fixture)
2525
/// <param name="paramName">The name of a parameter that can be set on this object</param>
2626
/// <param name="paramValue">A parameter value that can be set on this object</param>
2727
public void TestFeatureBase(
28-
FeatureBase<T> testObject,
28+
Params testObject,
2929
string paramName,
3030
object paramValue)
3131
{
@@ -37,8 +37,8 @@ public void TestFeatureBase(
3737
Assert.Equal(param.Parent, testObject.Uid());
3838

3939
Assert.NotEmpty(testObject.ExplainParam(param));
40-
testObject.Set(param, paramValue);
41-
Assert.IsAssignableFrom<Identifiable>(testObject.Clear(param));
40+
testObject.Set<T>(param, paramValue);
41+
Assert.IsAssignableFrom<Identifiable>(testObject.Clear<T>(param));
4242

4343
Assert.IsType<string>(testObject.Uid());
4444
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.IO;
6+
using Microsoft.Spark.ML.Feature;
7+
using Microsoft.Spark.Sql;
8+
using Microsoft.Spark.UnitTest.TestUtils;
9+
using Microsoft.Spark.Sql.Types;
10+
using Xunit;
11+
12+
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
    [Collection("Spark E2E Tests")]
    public class PipelineModelTests : FeatureBaseTests<PipelineModel>
    {
        private readonly SparkSession _spark;

        public PipelineModelTests(SparkFixture fixture) : base(fixture)
        {
            _spark = fixture.Spark;
        }

        /// <summary>
        /// Builds a <see cref="PipelineModel"/> around a single configured
        /// <see cref="Bucketizer"/> stage and verifies Transform, TransformSchema,
        /// and the Save/Load and Write/Read round trips.
        /// </summary>
        [Fact]
        public void TestPipelineModelTransform()
        {
            double[] splits = { double.MinValue, 0.0, 10.0, 50.0, double.MaxValue };

            const string handleInvalid = "skip";
            const string bucketizerUid = "uid";
            const string inputCol = "input_col";
            const string outputCol = "output_col";

            var bucketizer = new Bucketizer(bucketizerUid);
            bucketizer
                .SetInputCol(inputCol)
                .SetOutputCol(outputCol)
                .SetHandleInvalid(handleInvalid)
                .SetSplits(splits);

            var stages = new JavaTransformer[] { bucketizer };

            PipelineModel pipelineModel = new PipelineModel("randomUID", stages);

            DataFrame input = _spark.Sql("SELECT ID as input_col from range(100)");
            DataFrame output = pipelineModel.Transform(input);

            // The transformed frame must expose the bucketizer's output column.
            Assert.Contains(output.Schema().Fields, field => field.Name == outputCol);

            Assert.Equal(inputCol, bucketizer.GetInputCol());
            Assert.Equal(outputCol, bucketizer.GetOutputCol());
            Assert.Equal(splits, bucketizer.GetSplits());

            Assert.IsType<StructType>(pipelineModel.TransformSchema(input.Schema()));
            Assert.IsType<DataFrame>(output);

            using (var tempDirectory = new TemporaryDirectory())
            {
                // Round trip via the convenience Save/Load pair.
                string savePath = Path.Join(tempDirectory.Path, "pipelineModel");
                pipelineModel.Save(savePath);

                PipelineModel loadedModel = PipelineModel.Load(savePath);
                Assert.Equal(pipelineModel.Uid(), loadedModel.Uid());

                // Round trip via the explicit MLWriter/MLReader pair.
                string writePath = Path.Join(tempDirectory.Path, "pipelineModelWithWrite");
                pipelineModel.Write().Save(writePath);

                PipelineModel rereadModel = pipelineModel.Read().Load(writePath);
                Assert.Equal(pipelineModel.Uid(), rereadModel.Uid());
            }
        }
    }
}
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.IO;
6+
using Microsoft.Spark.ML.Feature;
7+
using Microsoft.Spark.Sql;
8+
using Microsoft.Spark.UnitTest.TestUtils;
9+
using Microsoft.Spark.Sql.Types;
10+
using Xunit;
11+
12+
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
13+
{
14+
[Collection("Spark E2E Tests")]
15+
public class PipelineTests : FeatureBaseTests<Pipeline>
16+
{
17+
private readonly SparkSession _spark;
18+
19+
public PipelineTests(SparkFixture fixture) : base(fixture)
20+
{
21+
_spark = fixture.Spark;
22+
}
23+
24+
/// <summary>
25+
/// Create a <see cref="Pipeline"/> and test the
26+
/// available methods. Test the FeatureBase methods
27+
/// using <see cref="TestFeatureBase"/>.
28+
/// </summary>
29+
[Fact]
30+
public void TestPipeline()
31+
{
32+
var stages = new JavaPipelineStage[] {
33+
new Bucketizer(),
34+
new CountVectorizer()
35+
};
36+
37+
Pipeline pipeline = new Pipeline()
38+
.SetStages(stages);
39+
JavaPipelineStage[] returnStages = pipeline.GetStages();
40+
41+
Assert.Equal(stages[0].Uid(), returnStages[0].Uid());
42+
Assert.Equal(stages[0].ToString(), returnStages[0].ToString());
43+
Assert.Equal(stages[1].Uid(), returnStages[1].Uid());
44+
Assert.Equal(stages[1].ToString(), returnStages[1].ToString());
45+
46+
using (var tempDirectory = new TemporaryDirectory())
47+
{
48+
string savePath = Path.Join(tempDirectory.Path, "pipeline");
49+
pipeline.Save(savePath);
50+
51+
Pipeline loadedPipeline = Pipeline.Load(savePath);
52+
Assert.Equal(pipeline.Uid(), loadedPipeline.Uid());
53+
}
54+
55+
TestFeatureBase(pipeline, "stages", stages);
56+
}
57+
58+
/// <summary>
59+
/// Create a <see cref="Pipeline"/> and test the
60+
/// fit and read/write methods.
61+
/// </summary>
62+
[Fact]
63+
public void TestPipelineFit()
64+
{
65+
DataFrame input = _spark.Sql("SELECT array('hello', 'I', 'AM', 'a', 'string', 'TO', " +
66+
"'TOKENIZE') as input from range(100)");
67+
68+
const string inputColumn = "input";
69+
const string outputColumn = "output";
70+
const double minDf = 1;
71+
const double minTf = 10;
72+
const int vocabSize = 10000;
73+
74+
CountVectorizer countVectorizer = new CountVectorizer()
75+
.SetInputCol(inputColumn)
76+
.SetOutputCol(outputColumn)
77+
.SetMinDF(minDf)
78+
.SetMinTF(minTf)
79+
.SetVocabSize(vocabSize);
80+
81+
var stages = new JavaPipelineStage[] {
82+
countVectorizer
83+
};
84+
85+
Pipeline pipeline = new Pipeline().SetStages(stages);
86+
PipelineModel pipelineModel = pipeline.Fit(input);
87+
88+
DataFrame output = pipelineModel.Transform(input);
89+
90+
Assert.IsType<StructType>(pipelineModel.TransformSchema(input.Schema()));
91+
Assert.IsType<DataFrame>(output);
92+
93+
using (var tempDirectory = new TemporaryDirectory())
94+
{
95+
string savePath = Path.Join(tempDirectory.Path, "pipeline");
96+
pipeline.Save(savePath);
97+
98+
Pipeline loadedPipeline = Pipeline.Load(savePath);
99+
Assert.Equal(pipeline.Uid(), loadedPipeline.Uid());
100+
101+
string writePath = Path.Join(tempDirectory.Path, "pipelineWithWrite");
102+
pipeline.Write().Save(writePath);
103+
104+
Pipeline loadedPipelineWithRead = pipeline.Read().Load(writePath);
105+
Assert.Equal(pipeline.Uid(), loadedPipelineWithRead.Uid());
106+
}
107+
108+
TestFeatureBase(pipeline, "stages", stages);
109+
}
110+
}
111+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Spark.Interop;
6+
using Microsoft.Spark.Interop.Internal.Java.Util;
7+
8+
namespace System
{
    /// <summary>
    /// ArrayExtensions hosts custom extension methods for the
    /// .NET base class array T[].
    /// </summary>
    public static class ArrayExtensions
    {
        /// <summary>
        /// Copies the elements of a .NET array into a newly created
        /// <c>java.util.ArrayList</c> on the JVM side.
        /// </summary>
        /// <param name="array">an array instance</param>
        /// <typeparam name="T">elements type of param array</typeparam>
        /// <returns><see cref="ArrayList"/></returns>
        internal static ArrayList ToJavaArrayList<T>(this T[] array)
        {
            var result = new ArrayList(SparkEnvironment.JvmBridge);
            foreach (T element in array)
            {
                result.Add(element);
            }

            return result;
        }
    }
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Spark.Interop;
6+
using Microsoft.Spark.Interop.Internal.Java.Util;
7+
8+
namespace System.Collections.Generic
{
    /// <summary>
    /// Dictionary hosts custom extension methods that transform a
    /// .NET Dictionary into its <c>java.util.HashMap</c> counterpart.
    /// </summary>
    public static class Dictionary
    {
        /// <summary>
        /// A custom extension method that helps transform from dotnet
        /// Dictionary&lt;string, string&gt; to java.util.HashMap.
        /// </summary>
        /// <param name="dictionary">a Dictionary instance</param>
        /// <returns><see cref="HashMap"/></returns>
        internal static HashMap ToJavaHashMap(this Dictionary<string, string> dictionary) =>
            ToJavaHashMapCore(dictionary);

        /// <summary>
        /// A custom extension method that helps transform from dotnet
        /// Dictionary&lt;string, object&gt; to java.util.HashMap.
        /// </summary>
        /// <param name="dictionary">a Dictionary instance</param>
        /// <returns><see cref="HashMap"/></returns>
        internal static HashMap ToJavaHashMap(this Dictionary<string, object> dictionary) =>
            ToJavaHashMapCore(dictionary);

        /// <summary>
        /// Shared implementation for both overloads: copies every key/value
        /// pair into a freshly constructed JVM-side HashMap. Kept private so
        /// the public surface stays limited to the two supported shapes.
        /// </summary>
        /// <param name="dictionary">a Dictionary instance</param>
        /// <typeparam name="TKey">key type of the dictionary</typeparam>
        /// <typeparam name="TValue">value type of the dictionary</typeparam>
        /// <returns><see cref="HashMap"/></returns>
        private static HashMap ToJavaHashMapCore<TKey, TValue>(
            Dictionary<TKey, TValue> dictionary)
        {
            var hashMap = new HashMap(SparkEnvironment.JvmBridge);
            foreach (KeyValuePair<TKey, TValue> item in dictionary)
            {
                hashMap.Put(item.Key, item.Value);
            }
            return hashMap;
        }
    }
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Spark.Interop.Ipc;
6+
7+
namespace Microsoft.Spark.Interop.Internal.Java.Util
{
    /// <summary>
    /// HashMap class represents a <c>java.util.HashMap</c> object.
    /// </summary>
    internal sealed class HashMap : IJvmObjectReferenceProvider
    {
        /// <summary>
        /// Create a <c>java.util.HashMap</c> JVM object
        /// </summary>
        /// <param name="jvm">JVM bridge to use</param>
        internal HashMap(IJvmBridge jvm) =>
            Reference = jvm.CallConstructor("java.util.HashMap");

        public JvmObjectReference Reference { get; private set; }

        /// <summary>
        /// Associates the specified value with the specified key in this map.
        /// If the map previously contained a mapping for the key, the old value is replaced.
        /// </summary>
        /// <param name="key">key with which the specified value is to be associated</param>
        /// <param name="value">value to be associated with the specified key</param>
        internal void Put(object key, object value) =>
            Reference.Invoke("put", key, value);

        /// <summary>
        /// Returns the value to which the specified key is mapped,
        /// or null if this map contains no mapping for the key.
        /// </summary>
        /// <param name="key">key whose associated value is to be returned</param>
        /// <return>value associated with the specified key</return>
        internal object Get(object key) =>
            Reference.Invoke("get", key);

        /// <summary>
        /// Returns true if this map maps one or more keys to the specified value.
        /// </summary>
        /// <param name="value">value whose presence in this map is to be tested</param>
        /// <return>true if this map maps one or more keys to the specified value</return>
        internal bool ContainsValue(object value) =>
            (bool)Reference.Invoke("containsValue", value);

        /// <summary>
        /// Returns an array of the keys contained in this map.
        /// </summary>
        /// <return>An array of object hosting the keys contained in the map</return>
        internal object[] Keys()
        {
            // keySet() yields a java.util.Set; toArray() flattens it for the CLR side.
            var keySet = (JvmObjectReference)Reference.Invoke("keySet");
            return (object[])keySet.Invoke("toArray");
        }
    }
}

0 commit comments

Comments
 (0)