Skip to content

Commit ec81890

Browse files
authored
Introduce a base class for Spark.ML.Features (#574)
1 parent 2a597d8 commit ec81890

File tree

9 files changed

+222
-415
lines changed

9 files changed

+222
-415
lines changed

src/csharp/Microsoft.Spark/ML/Feature/Bucketizer.cs

Lines changed: 40 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -20,34 +20,29 @@ namespace Microsoft.Spark.ML.Feature
2020
/// will be thrown. The splits parameter is only used for single column usage, and splitsArray
2121
/// is for multiple columns.
2222
/// </summary>
23-
public class Bucketizer : IJvmObjectReferenceProvider
23+
public class Bucketizer : FeatureBase<Bucketizer>, IJvmObjectReferenceProvider
2424
{
2525
private static readonly string s_bucketizerClassName =
2626
"org.apache.spark.ml.feature.Bucketizer";
2727

28-
private readonly JvmObjectReference _jvmObject;
29-
3028
/// <summary>
3129
/// Create a <see cref="Bucketizer"/> without any parameters
3230
/// </summary>
33-
public Bucketizer()
31+
public Bucketizer() : base(s_bucketizerClassName)
3432
{
35-
_jvmObject = SparkEnvironment.JvmBridge.CallConstructor(s_bucketizerClassName);
3633
}
3734

3835
/// <summary>
3936
/// Create a <see cref="Bucketizer"/> with a UID that is used to give the
4037
/// <see cref="Bucketizer"/> a unique ID
4138
/// </summary>
4239
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
43-
public Bucketizer(string uid)
40+
public Bucketizer(string uid) : base(s_bucketizerClassName, uid)
4441
{
45-
_jvmObject = SparkEnvironment.JvmBridge.CallConstructor(s_bucketizerClassName, uid);
4642
}
4743

48-
internal Bucketizer(JvmObjectReference jvmObject)
44+
internal Bucketizer(JvmObjectReference jvmObject) : base(jvmObject)
4945
{
50-
_jvmObject = jvmObject;
5146
}
5247

5348
JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;
@@ -56,11 +51,8 @@ internal Bucketizer(JvmObjectReference jvmObject)
5651
/// Gets the splits that were set using SetSplits
5752
/// </summary>
5853
/// <returns>double[], the splits to be used to bucket the input column</returns>
59-
public double[] GetSplits()
60-
{
61-
return (double[])_jvmObject.Invoke("getSplits");
62-
}
63-
54+
public double[] GetSplits() => (double[])_jvmObject.Invoke("getSplits");
55+
6456
/// <summary>
6557
/// Split points for splitting a single column into buckets. To split multiple columns use
6658
/// SetSplitsArray. You cannot use both SetSplits and SetSplitsArray at the same time
@@ -72,20 +64,15 @@ public double[] GetSplits()
7264
/// increasing. Values outside the splits specified will be treated as errors.
7365
/// </param>
7466
/// <returns>New <see cref="Bucketizer"/> object</returns>
75-
public Bucketizer SetSplits(double[] value)
76-
{
77-
return WrapAsBucketizer(_jvmObject.Invoke("setSplits", value));
78-
}
67+
public Bucketizer SetSplits(double[] value) =>
68+
WrapAsBucketizer(_jvmObject.Invoke("setSplits", value));
7969

8070
/// <summary>
8171
/// Gets the splits that were set by SetSplitsArray
8272
/// </summary>
8373
/// <returns>double[][], the splits to be used to bucket the input columns</returns>
84-
public double[][] GetSplitsArray()
85-
{
86-
return (double[][])_jvmObject.Invoke("getSplitsArray");
87-
}
88-
74+
public double[][] GetSplitsArray() => (double[][])_jvmObject.Invoke("getSplitsArray");
75+
8976
/// <summary>
9077
/// Split points for splitting multiple columns into buckets. To split a single column use
9178
/// SetSplits. You cannot use both SetSplits and SetSplitsArray at the same time.
@@ -97,41 +84,32 @@ public double[][] GetSplitsArray()
9784
/// includes y. The splits should be of length &gt;= 3 and strictly increasing.
9885
/// Values outside the splits specified will be treated as errors.</param>
9986
/// <returns>New <see cref="Bucketizer"/> object</returns>
100-
public Bucketizer SetSplitsArray(double[][] value)
101-
{
102-
return WrapAsBucketizer(_jvmObject.Invoke("setSplitsArray", (object)value));
103-
}
87+
public Bucketizer SetSplitsArray(double[][] value) =>
88+
WrapAsBucketizer(_jvmObject.Invoke("setSplitsArray", (object)value));
10489

10590
/// <summary>
10691
/// Gets the column that the <see cref="Bucketizer"/> should read from and convert into
10792
/// buckets. This would have been set by SetInputCol
10893
/// </summary>
10994
/// <returns>string, the input column</returns>
110-
public string GetInputCol()
111-
{
112-
return (string)_jvmObject.Invoke("getInputCol");
113-
}
114-
95+
public string GetInputCol() => (string)_jvmObject.Invoke("getInputCol");
96+
11597
/// <summary>
11698
/// Sets the column that the <see cref="Bucketizer"/> should read from and convert into
11799
/// buckets
118100
/// </summary>
119101
/// <param name="value">The name of the column to use as the source of the buckets</param>
120102
/// <returns>New <see cref="Bucketizer"/> object</returns>
121-
public Bucketizer SetInputCol(string value)
122-
{
123-
return WrapAsBucketizer(_jvmObject.Invoke("setInputCol", value));
124-
}
125-
103+
public Bucketizer SetInputCol(string value) =>
104+
WrapAsBucketizer(_jvmObject.Invoke("setInputCol", value));
105+
126106
/// <summary>
127107
/// Gets the columns that <see cref="Bucketizer"/> should read from and convert into
128108
/// buckets. This is set by SetInputCols
129109
/// </summary>
130110
/// <returns>IEnumerable&lt;string&gt;, list of input columns</returns>
131-
public IEnumerable<string> GetInputCols()
132-
{
133-
return ((string[])(_jvmObject.Invoke("getInputCols"))).ToList();
134-
}
111+
public IEnumerable<string> GetInputCols() =>
112+
((string[])(_jvmObject.Invoke("getInputCols"))).ToList();
135113

136114
/// <summary>
137115
/// Sets the columns that <see cref="Bucketizer"/> should read from and convert into
@@ -142,73 +120,50 @@ public IEnumerable<string> GetInputCols()
142120
/// </summary>
143121
/// <param name="value">List of input columns to use as sources for buckets</param>
144122
/// <returns>New <see cref="Bucketizer"/> object</returns>
145-
public Bucketizer SetInputCols(IEnumerable<string> value)
146-
{
147-
return WrapAsBucketizer(_jvmObject.Invoke("setInputCols", value));
148-
}
149-
123+
public Bucketizer SetInputCols(IEnumerable<string> value) =>
124+
WrapAsBucketizer(_jvmObject.Invoke("setInputCols", value));
125+
150126
/// <summary>
151127
/// Gets the name of the column the output data will be written to. This is set by
152128
/// SetOutputCol
153129
/// </summary>
154130
/// <returns>string, the output column</returns>
155-
public string GetOutputCol()
156-
{
157-
return (string)_jvmObject.Invoke("getOutputCol");
158-
}
159-
131+
public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol");
132+
160133
/// <summary>
161134
/// The <see cref="Bucketizer"/> will create a new column in the DataFrame, this is the
162135
/// name of the new column.
163136
/// </summary>
164137
/// <param name="value">The name of the new column which contains the bucket ID</param>
165138
/// <returns>New <see cref="Bucketizer"/> object</returns>
166-
public Bucketizer SetOutputCol(string value)
167-
{
168-
return WrapAsBucketizer(_jvmObject.Invoke("setOutputCol", value));
169-
}
139+
public Bucketizer SetOutputCol(string value) =>
140+
WrapAsBucketizer(_jvmObject.Invoke("setOutputCol", value));
170141

171142
/// <summary>
172143
/// The list of columns that the <see cref="Bucketizer"/> will create in the DataFrame.
173144
/// This is set by SetOutputCols
174145
/// </summary>
175146
/// <returns>IEnumerable&lt;string&gt;, list of output columns</returns>
176-
public IEnumerable<string> GetOutputCols()
177-
{
178-
return ((string[])_jvmObject.Invoke("getOutputCols")).ToList();
179-
}
180-
147+
public IEnumerable<string> GetOutputCols() =>
148+
((string[])_jvmObject.Invoke("getOutputCols")).ToList();
149+
181150
/// <summary>
182151
/// The list of columns that the <see cref="Bucketizer"/> will create in the DataFrame.
183152
/// </summary>
184153
/// <param name="value">List of column names which will contain the bucket ID</param>
185154
/// <returns>New <see cref="Bucketizer"/> object</returns>
186-
public Bucketizer SetOutputCols(List<string> value)
187-
{
188-
return WrapAsBucketizer(_jvmObject.Invoke("setOutputCols", value));
189-
}
190-
155+
public Bucketizer SetOutputCols(List<string> value) =>
156+
WrapAsBucketizer(_jvmObject.Invoke("setOutputCols", value));
157+
191158
/// <summary>
192159
/// Loads the <see cref="Bucketizer"/> that was previously saved using Save
193160
/// </summary>
194161
/// <param name="path">The path the previous <see cref="Bucketizer"/> was saved to</param>
195162
/// <returns>New <see cref="Bucketizer"/> object</returns>
196-
public static Bucketizer Load(string path)
197-
{
198-
return WrapAsBucketizer(
163+
public static Bucketizer Load(string path) =>
164+
WrapAsBucketizer(
199165
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
200166
s_bucketizerClassName,"load", path));
201-
}
202-
203-
/// <summary>
204-
/// Saves the <see cref="Bucketizer"/> so that it can be loaded later using Load
205-
/// </summary>
206-
/// <param name="path">The path to save the <see cref="Bucketizer"/> to</param>
207-
/// <returns>New <see cref="Bucketizer"/> object</returns>
208-
public Bucketizer Save(string path)
209-
{
210-
return WrapAsBucketizer(_jvmObject.Invoke("save", path));
211-
}
212167

213168
/// <summary>
214169
/// Executes the <see cref="Bucketizer"/> and transforms the DataFrame to include the new
@@ -218,31 +173,15 @@ public Bucketizer Save(string path)
218173
/// <returns>
219174
/// <see cref="DataFrame"/> containing the original data and the new bucketed columns
220175
/// </returns>
221-
public DataFrame Transform(DataFrame source)
222-
{
223-
return new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", source));
224-
}
225-
226-
/// <summary>
227-
/// The uid that was used to create the <see cref="Bucketizer"/>. If no UID is passed in
228-
/// when creating the <see cref="Bucketizer"/> then a random UID is created when the
229-
/// <see cref="Bucketizer"/> is created.
230-
/// </summary>
231-
/// <returns>string UID identifying the <see cref="Bucketizer"/></returns>
232-
public string Uid()
233-
{
234-
return (string)_jvmObject.Invoke("uid");
235-
}
176+
public DataFrame Transform(DataFrame source) =>
177+
new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", source));
236178

237179
/// <summary>
238180
/// How should the <see cref="Bucketizer"/> handle invalid data, choices are "skip",
239181
/// "error" or "keep"
240182
/// </summary>
241183
/// <returns>string showing the way Spark will handle invalid data</returns>
242-
public string GetHandleInvalid()
243-
{
244-
return (string)_jvmObject.Invoke("getHandleInvalid");
245-
}
184+
public string GetHandleInvalid() => (string)_jvmObject.Invoke("getHandleInvalid");
246185

247186
/// <summary>
248187
/// Tells the <see cref="Bucketizer"/> what to do with invalid data.
@@ -251,11 +190,9 @@ public string GetHandleInvalid()
251190
/// </summary>
252191
/// <param name="value">"skip", "error" or "keep"</param>
253192
/// <returns>New <see cref="Bucketizer"/> object</returns>
254-
public Bucketizer SetHandleInvalid(string value)
255-
{
256-
return WrapAsBucketizer(_jvmObject.Invoke("setHandleInvalid", value.ToString()));
257-
}
258-
193+
public Bucketizer SetHandleInvalid(string value) =>
194+
WrapAsBucketizer(_jvmObject.Invoke("setHandleInvalid", value.ToString()));
195+
259196
private static Bucketizer WrapAsBucketizer(object obj) =>
260197
new Bucketizer((JvmObjectReference)obj);
261198
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using System;
2+
using System.Linq;
3+
using System.Reflection;
4+
using Microsoft.Spark.Interop;
5+
using Microsoft.Spark.Interop.Ipc;
6+
7+
namespace Microsoft.Spark.ML.Feature
8+
{
9+
/// <summary>
10+
/// FeatureBase shares code amongst all of the ML.Feature objects; there are a few
11+
/// interfaces that the Scala code implements across all of the objects. This should help to
12+
/// write the extra objects faster.
13+
/// </summary>
14+
/// <typeparam name="T">
15+
/// The class that implements FeatureBase, this is needed so we can create new objects where
16+
/// Spark returns new objects rather than updating existing objects.
17+
/// </typeparam>
18+
public class FeatureBase<T> : Identifiable
19+
{
20+
internal readonly JvmObjectReference _jvmObject;
21+
22+
internal FeatureBase(string className)
23+
: this(SparkEnvironment.JvmBridge.CallConstructor(className))
24+
{
25+
}
26+
27+
internal FeatureBase(string className, string uid)
28+
: this(SparkEnvironment.JvmBridge.CallConstructor(className, uid))
29+
{
30+
}
31+
32+
internal FeatureBase(JvmObjectReference jvmObject)
33+
{
34+
_jvmObject = jvmObject;
35+
}
36+
37+
/// <summary>
38+
/// Returns the JVM toString value rather than the .NET ToString default
39+
/// </summary>
40+
/// <returns>JVM toString() value</returns>
41+
public override string ToString() => (string)_jvmObject.Invoke("toString");
42+
43+
/// <summary>
44+
/// The UID that was used to create the object. If no UID is passed in when creating the
45+
/// object then a random UID is created when the object is created.
46+
/// </summary>
47+
/// <returns>string UID identifying the object</returns>
48+
public string Uid() => (string)_jvmObject.Invoke("uid");
49+
50+
/// <summary>
51+
/// Saves the object so that it can be loaded later using Load. Note that these objects
52+
/// can be shared with Scala by Loading or Saving in Scala.
53+
/// </summary>
54+
/// <param name="path">The path to save the object to</param>
55+
/// <returns>New object</returns>
56+
public T Save(string path) =>
57+
WrapAsType((JvmObjectReference)_jvmObject.Invoke("save", path));
58+
59+
private T WrapAsType(JvmObjectReference reference)
60+
{
61+
ConstructorInfo constructor = typeof(T)
62+
.GetConstructors(BindingFlags.NonPublic | BindingFlags.Instance)
63+
.Single(c =>
64+
{
65+
ParameterInfo[] parameters = c.GetParameters();
66+
return (parameters.Length == 1) &&
67+
(parameters[0].ParameterType == typeof(JvmObjectReference));
68+
});
69+
70+
return (T)constructor.Invoke(new object[] {reference});
71+
}
72+
}
73+
}

0 commit comments

Comments
 (0)