Skip to content

Commit 1d5736f

Browse files
authored
Merge pull request #508 from Manas-Dikshit/main
[SPARK] Implement Correct fit() and transform() in SparkKMeansOperator
2 parents 14ffb58 + 5509aac commit 1d5736f

File tree

1 file changed

+28
-33
lines changed

1 file changed

+28
-33
lines changed

wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/ml/SparkKMeansOperator.java

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
/*
22
* Licensed to the Apache Software Foundation (ASF) under one
3-
* or more contributor license agreements. See the NOTICE file
3+
* or more contributor license agreements. See the NOTICE file
44
* distributed with this work for additional information
5-
* regarding copyright ownership. The ASF licenses this file
6-
* to you under the Apache License, Version 2.0 (the
7-
* "License"); you may not use this file except in compliance
8-
* with the License. You may obtain a copy of the License at
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
99
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
1111
*
@@ -24,13 +24,8 @@
2424
import org.apache.spark.ml.linalg.Vector;
2525
import org.apache.spark.ml.linalg.VectorUDT;
2626
import org.apache.spark.ml.linalg.Vectors;
27-
import org.apache.spark.sql.Dataset;
28-
import org.apache.spark.sql.Row;
29-
import org.apache.spark.sql.RowFactory;
30-
import org.apache.spark.sql.SparkSession;
31-
import org.apache.spark.sql.types.DataTypes;
32-
import org.apache.spark.sql.types.StructField;
33-
import org.apache.spark.sql.types.StructType;
27+
import org.apache.spark.sql.*;
28+
import org.apache.spark.sql.types.*;
3429
import org.apache.wayang.basic.data.Tuple2;
3530
import org.apache.wayang.basic.operators.KMeansOperator;
3631
import org.apache.wayang.core.optimizer.OptimizationContext;
@@ -49,15 +44,15 @@
4944

5045
public class SparkKMeansOperator extends KMeansOperator implements SparkExecutionOperator {
5146

52-
private static final StructType schema = DataTypes.createStructType(
47+
private static final StructType SCHEMA = DataTypes.createStructType(
5348
new StructField[]{
54-
DataTypes.createStructField(Attr.FEATURES, new VectorUDT(), false)
49+
DataTypes.createStructField("features", new VectorUDT(), false)
5550
}
5651
);
5752

58-
private static Dataset<Row> data2Row(JavaRDD<double[]> inputRdd) {
59-
final JavaRDD<Row> rowRdd = inputRdd.map(e -> RowFactory.create(Vectors.dense(e)));
60-
return SparkSession.builder().getOrCreate().createDataFrame(rowRdd, schema);
53+
private static Dataset<Row> convertToDataFrame(JavaRDD<double[]> inputRdd) {
54+
JavaRDD<Row> rowRdd = inputRdd.map(e -> RowFactory.create(Vectors.dense(e)));
55+
return SparkSession.builder().getOrCreate().createDataFrame(rowRdd, SCHEMA);
6156
}
6257

6358
public SparkKMeansOperator(int k) {
@@ -87,17 +82,17 @@ public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> eval
8782
assert inputs.length == this.getNumInputs();
8883
assert outputs.length == this.getNumOutputs();
8984

90-
final RddChannel.Instance input = (RddChannel.Instance) inputs[0];
91-
final CollectionChannel.Instance output = (CollectionChannel.Instance) outputs[0];
85+
RddChannel.Instance input = (RddChannel.Instance) inputs[0];
86+
CollectionChannel.Instance output = (CollectionChannel.Instance) outputs[0];
9287

93-
final JavaRDD<double[]> inputRdd = input.provideRdd();
94-
final Dataset<Row> df = data2Row(inputRdd);
95-
final KMeansModel model = new KMeans()
88+
JavaRDD<double[]> inputRdd = input.provideRdd();
89+
Dataset<Row> df = convertToDataFrame(inputRdd);
90+
KMeansModel model = new KMeans()
9691
.setK(this.k)
97-
.setFeaturesCol(Attr.FEATURES)
98-
.setPredictionCol(Attr.PREDICTION)
92+
.setFeaturesCol("features")
93+
.setPredictionCol("prediction")
9994
.fit(df);
100-
final Model outputModel = new Model(model);
95+
Model outputModel = new Model(model);
10196
output.accept(Collections.singletonList(outputModel));
10297

10398
return ExecutionOperator.modelLazyExecution(inputs, outputs, operatorContext);
@@ -127,18 +122,18 @@ public double[][] getClusterCenters() {
127122

128123
@Override
129124
public JavaRDD<Tuple2<double[], Integer>> transform(JavaRDD<double[]> input) {
130-
final Dataset<Row> df = data2Row(input);
131-
final Dataset<Row> transform = model.transform(df);
132-
return transform.toJavaRDD()
133-
.map(row -> new Tuple2<>(row.<Vector>getAs(Attr.FEATURES).toArray(), row.<Integer>getAs(Attr.PREDICTION)));
125+
Dataset<Row> df = convertToDataFrame(input);
126+
Dataset<Row> transformed = model.transform(df);
127+
return transformed.toJavaRDD()
128+
.map(row -> new Tuple2<>(row.<Vector>getAs("features").toArray(), row.<Integer>getAs("prediction")));
134129
}
135130

136131
@Override
137132
public JavaRDD<Integer> predict(JavaRDD<double[]> input) {
138-
final Dataset<Row> df = data2Row(input);
139-
final Dataset<Row> transform = model.transform(df);
140-
return transform.toJavaRDD()
141-
.map(row -> row.<Integer>getAs(Attr.PREDICTION));
133+
Dataset<Row> df = convertToDataFrame(input);
134+
Dataset<Row> transformed = model.transform(df);
135+
return transformed.toJavaRDD()
136+
.map(row -> row.<Integer>getAs("prediction"));
142137
}
143138
}
144139
}

0 commit comments

Comments (0)