|
1 | 1 | /* |
2 | 2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | | - * or more contributor license agreements. See the NOTICE file |
| 3 | + * or more contributor license agreements. See the NOTICE file |
4 | 4 | * distributed with this work for additional information |
5 | | - * regarding copyright ownership. The ASF licenses this file |
6 | | - * to you under the Apache License, Version 2.0 (the |
7 | | - * "License"); you may not use this file except in compliance |
8 | | - * with the License. You may obtain a copy of the License at |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the "License"); |
| 7 | + * you may not use this file except in compliance with the License. |
| 8 | + * You may obtain a copy of the License at |
9 | 9 | * |
10 | 10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | 11 | * |
|
24 | 24 | import org.apache.spark.ml.linalg.Vector; |
25 | 25 | import org.apache.spark.ml.linalg.VectorUDT; |
26 | 26 | import org.apache.spark.ml.linalg.Vectors; |
27 | | -import org.apache.spark.sql.Dataset; |
28 | | -import org.apache.spark.sql.Row; |
29 | | -import org.apache.spark.sql.RowFactory; |
30 | | -import org.apache.spark.sql.SparkSession; |
31 | | -import org.apache.spark.sql.types.DataTypes; |
32 | | -import org.apache.spark.sql.types.StructField; |
33 | | -import org.apache.spark.sql.types.StructType; |
| 27 | +import org.apache.spark.sql.*; |
| 28 | +import org.apache.spark.sql.types.*; |
34 | 29 | import org.apache.wayang.basic.data.Tuple2; |
35 | 30 | import org.apache.wayang.basic.operators.KMeansOperator; |
36 | 31 | import org.apache.wayang.core.optimizer.OptimizationContext; |
|
49 | 44 |
|
50 | 45 | public class SparkKMeansOperator extends KMeansOperator implements SparkExecutionOperator { |
51 | 46 |
|
52 | | - private static final StructType schema = DataTypes.createStructType( |
| 47 | + private static final StructType SCHEMA = DataTypes.createStructType( |
53 | 48 | new StructField[]{ |
54 | | - DataTypes.createStructField(Attr.FEATURES, new VectorUDT(), false) |
| 49 | + DataTypes.createStructField("features", new VectorUDT(), false) |
55 | 50 | } |
56 | 51 | ); |
57 | 52 |
|
58 | | - private static Dataset<Row> data2Row(JavaRDD<double[]> inputRdd) { |
59 | | - final JavaRDD<Row> rowRdd = inputRdd.map(e -> RowFactory.create(Vectors.dense(e))); |
60 | | - return SparkSession.builder().getOrCreate().createDataFrame(rowRdd, schema); |
| 53 | + private static Dataset<Row> convertToDataFrame(JavaRDD<double[]> inputRdd) { |
| 54 | + JavaRDD<Row> rowRdd = inputRdd.map(e -> RowFactory.create(Vectors.dense(e))); |
| 55 | + return SparkSession.builder().getOrCreate().createDataFrame(rowRdd, SCHEMA); |
61 | 56 | } |
62 | 57 |
|
63 | 58 | public SparkKMeansOperator(int k) { |
@@ -87,17 +82,17 @@ public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> eval |
87 | 82 | assert inputs.length == this.getNumInputs(); |
88 | 83 | assert outputs.length == this.getNumOutputs(); |
89 | 84 |
|
90 | | - final RddChannel.Instance input = (RddChannel.Instance) inputs[0]; |
91 | | - final CollectionChannel.Instance output = (CollectionChannel.Instance) outputs[0]; |
| 85 | + RddChannel.Instance input = (RddChannel.Instance) inputs[0]; |
| 86 | + CollectionChannel.Instance output = (CollectionChannel.Instance) outputs[0]; |
92 | 87 |
|
93 | | - final JavaRDD<double[]> inputRdd = input.provideRdd(); |
94 | | - final Dataset<Row> df = data2Row(inputRdd); |
95 | | - final KMeansModel model = new KMeans() |
| 88 | + JavaRDD<double[]> inputRdd = input.provideRdd(); |
| 89 | + Dataset<Row> df = convertToDataFrame(inputRdd); |
| 90 | + KMeansModel model = new KMeans() |
96 | 91 | .setK(this.k) |
97 | | - .setFeaturesCol(Attr.FEATURES) |
98 | | - .setPredictionCol(Attr.PREDICTION) |
| 92 | + .setFeaturesCol("features") |
| 93 | + .setPredictionCol("prediction") |
99 | 94 | .fit(df); |
100 | | - final Model outputModel = new Model(model); |
| 95 | + Model outputModel = new Model(model); |
101 | 96 | output.accept(Collections.singletonList(outputModel)); |
102 | 97 |
|
103 | 98 | return ExecutionOperator.modelLazyExecution(inputs, outputs, operatorContext); |
@@ -127,18 +122,18 @@ public double[][] getClusterCenters() { |
127 | 122 |
|
128 | 123 | @Override |
129 | 124 | public JavaRDD<Tuple2<double[], Integer>> transform(JavaRDD<double[]> input) { |
130 | | - final Dataset<Row> df = data2Row(input); |
131 | | - final Dataset<Row> transform = model.transform(df); |
132 | | - return transform.toJavaRDD() |
133 | | - .map(row -> new Tuple2<>(row.<Vector>getAs(Attr.FEATURES).toArray(), row.<Integer>getAs(Attr.PREDICTION))); |
| 125 | + Dataset<Row> df = convertToDataFrame(input); |
| 126 | + Dataset<Row> transformed = model.transform(df); |
| 127 | + return transformed.toJavaRDD() |
| 128 | + .map(row -> new Tuple2<>(row.<Vector>getAs("features").toArray(), row.<Integer>getAs("prediction"))); |
134 | 129 | } |
135 | 130 |
|
136 | 131 | @Override |
137 | 132 | public JavaRDD<Integer> predict(JavaRDD<double[]> input) { |
138 | | - final Dataset<Row> df = data2Row(input); |
139 | | - final Dataset<Row> transform = model.transform(df); |
140 | | - return transform.toJavaRDD() |
141 | | - .map(row -> row.<Integer>getAs(Attr.PREDICTION)); |
| 133 | + Dataset<Row> df = convertToDataFrame(input); |
| 134 | + Dataset<Row> transformed = model.transform(df); |
| 135 | + return transformed.toJavaRDD() |
| 136 | + .map(row -> row.<Integer>getAs("prediction")); |
142 | 137 | } |
143 | 138 | } |
144 | 139 | } |
0 commit comments