Commit d3ae3e1

Authored by WeichenXu123, committed by yanboliang
[SPARK-19634][SQL][ML][FOLLOW-UP] Improve interface of dataframe vectorized summarizer
## What changes were proposed in this pull request?

Make several improvements to the dataframe vectorized summarizer:

1. Make the summarizer return `Vector` type for all metrics (except "count"). Previously it returned `WrappedArray`, which was not very convenient.
2. Make `MetricsAggregate` inherit the `ImplicitCastInputTypes` trait, so it can check and implicitly cast input values.
3. Add a "weight" parameter to every single-metric method.
4. Update the docs and improve the example code in them.
5. Simplify the test cases.

## How was this patch tested?

Tests added and simplified.

Author: WeichenXu <[email protected]>

Closes #19156 from WeichenXu123/improve_vec_summarizer.
Parent: 9c289a5
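For orientation before the diff: a minimal sketch of the improved interface, per the changes listed above. This is untested example code, not part of the patch; the SparkSession value `spark`, the toy rows, and the column names `features`/`weight` are assumptions for illustration.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col

import spark.implicits._  // assumes a SparkSession named `spark`

// Toy data: two feature vectors with per-row weights.
val df = Seq(
  (Vectors.dense(1.0, 2.0), 1.0),
  (Vectors.dense(3.0, 4.0), 2.0)
).toDF("features", "weight")

// metrics(...) is now plain varargs, summary(...) accepts a weight column,
// and every metric except "count" comes back as a Vector (not WrappedArray).
val Row(Row(mean: Vector, max: Vector, count: Long)) = df
  .select(Summarizer.metrics("mean", "max", "count")
    .summary(col("features"), col("weight")))
  .first()

// The single-metric shortcuts gained the same weight overload.
val meanDF = df.select(Summarizer.mean(col("features"), col("weight")))
```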

File tree

3 files changed: +341 −213 lines changed


mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala

Lines changed: 85 additions & 43 deletions
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
@@ -41,7 +41,7 @@ sealed abstract class SummaryBuilder {
   /**
    * Returns an aggregate object that contains the summary of the column with the requested metrics.
    * @param featuresCol a column that contains features Vector object.
-   * @param weightCol a column that contains weight value.
+   * @param weightCol a column that contains weight value. Default weight is 1.0.
    * @return an aggregate column that contains the statistics. The exact content of this
    *         structure is determined during the creation of the builder.
    */
@@ -50,6 +50,7 @@ sealed abstract class SummaryBuilder {
 
   @Since("2.3.0")
   def summary(featuresCol: Column): Column = summary(featuresCol, lit(1.0))
+
 }
 
 /**
@@ -60,15 +61,18 @@ sealed abstract class SummaryBuilder {
 * This class lets users pick the statistics they would like to extract for a given column. Here is
 * an example in Scala:
 * {{{
-*   val dataframe = ... // Some dataframe containing a feature column
-*   val allStats = dataframe.select(Summarizer.metrics("min", "max").summary($"features"))
-*   val Row(Row(min_, max_)) = allStats.first()
+*   import org.apache.spark.ml.linalg._
+*   import org.apache.spark.sql.Row
+*   val dataframe = ... // Some dataframe containing a feature column and a weight column
+*   val multiStatsDF = dataframe.select(
+*     Summarizer.metrics("min", "max", "count").summary($"features", $"weight")
+*   val Row(Row(minVec, maxVec, count)) = multiStatsDF.first()
 * }}}
 *
 * If one wants to get a single metric, shortcuts are also available:
 * {{{
 *   val meanDF = dataframe.select(Summarizer.mean($"features"))
-*   val Row(mean_) = meanDF.first()
+*   val Row(meanVec) = meanDF.first()
 * }}}
 *
 * Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD
@@ -94,46 +98,87 @@ object Summarizer extends Logging {
    *  - min: the minimum for each coefficient.
    *  - normL2: the Euclidian norm for each coefficient.
    *  - normL1: the L1 norm of each coefficient (sum of the absolute values).
-   * @param firstMetric the metric being provided
-   * @param metrics additional metrics that can be provided.
+   * @param metrics metrics that can be provided.
    * @return a builder.
    * @throws IllegalArgumentException if one of the metric names is not understood.
    *
    * Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD
    * interface.
    */
   @Since("2.3.0")
-  def metrics(firstMetric: String, metrics: String*): SummaryBuilder = {
-    val (typedMetrics, computeMetrics) = getRelevantMetrics(Seq(firstMetric) ++ metrics)
+  @scala.annotation.varargs
+  def metrics(metrics: String*): SummaryBuilder = {
+    require(metrics.size >= 1, "Should include at least one metric")
+    val (typedMetrics, computeMetrics) = getRelevantMetrics(metrics)
     new SummaryBuilderImpl(typedMetrics, computeMetrics)
   }
 
   @Since("2.3.0")
-  def mean(col: Column): Column = getSingleMetric(col, "mean")
+  def mean(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "mean")
+  }
+
+  @Since("2.3.0")
+  def mean(col: Column): Column = mean(col, lit(1.0))
+
+  @Since("2.3.0")
+  def variance(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "variance")
+  }
+
+  @Since("2.3.0")
+  def variance(col: Column): Column = variance(col, lit(1.0))
+
+  @Since("2.3.0")
+  def count(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "count")
+  }
+
+  @Since("2.3.0")
+  def count(col: Column): Column = count(col, lit(1.0))
 
   @Since("2.3.0")
-  def variance(col: Column): Column = getSingleMetric(col, "variance")
+  def numNonZeros(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "numNonZeros")
+  }
+
+  @Since("2.3.0")
+  def numNonZeros(col: Column): Column = numNonZeros(col, lit(1.0))
+
+  @Since("2.3.0")
+  def max(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "max")
+  }
+
+  @Since("2.3.0")
+  def max(col: Column): Column = max(col, lit(1.0))
 
   @Since("2.3.0")
-  def count(col: Column): Column = getSingleMetric(col, "count")
+  def min(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "min")
+  }
 
   @Since("2.3.0")
-  def numNonZeros(col: Column): Column = getSingleMetric(col, "numNonZeros")
+  def min(col: Column): Column = min(col, lit(1.0))
 
   @Since("2.3.0")
-  def max(col: Column): Column = getSingleMetric(col, "max")
+  def normL1(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "normL1")
+  }
 
   @Since("2.3.0")
-  def min(col: Column): Column = getSingleMetric(col, "min")
+  def normL1(col: Column): Column = normL1(col, lit(1.0))
 
   @Since("2.3.0")
-  def normL1(col: Column): Column = getSingleMetric(col, "normL1")
+  def normL2(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "normL2")
+  }
 
   @Since("2.3.0")
-  def normL2(col: Column): Column = getSingleMetric(col, "normL2")
+  def normL2(col: Column): Column = normL2(col, lit(1.0))
 
-  private def getSingleMetric(col: Column, metric: String): Column = {
-    val c1 = metrics(metric).summary(col)
+  private def getSingleMetric(col: Column, weightCol: Column, metric: String): Column = {
+    val c1 = metrics(metric).summary(col, weightCol)
     c1.getField(metric).as(s"$metric($col)")
   }
 }
@@ -187,8 +232,7 @@ private[ml] object SummaryBuilderImpl extends Logging {
     StructType(fields)
   }
 
-  private val arrayDType = ArrayType(DoubleType, containsNull = false)
-  private val arrayLType = ArrayType(LongType, containsNull = false)
+  private val vectorUDT = new VectorUDT
 
   /**
    * All the metrics that can be currently computed by Spark for vectors.
@@ -197,14 +241,14 @@ private[ml] object SummaryBuilderImpl extends Logging {
    * metrics that need to de computed internally to get the final result.
    */
   private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq(
-    ("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)),
-    ("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
+    ("mean", Mean, vectorUDT, Seq(ComputeMean, ComputeWeightSum)),
+    ("variance", Variance, vectorUDT, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
     ("count", Count, LongType, Seq()),
-    ("numNonZeros", NumNonZeros, arrayLType, Seq(ComputeNNZ)),
-    ("max", Max, arrayDType, Seq(ComputeMax, ComputeNNZ)),
-    ("min", Min, arrayDType, Seq(ComputeMin, ComputeNNZ)),
-    ("normL2", NormL2, arrayDType, Seq(ComputeM2)),
-    ("normL1", NormL1, arrayDType, Seq(ComputeL1))
+    ("numNonZeros", NumNonZeros, vectorUDT, Seq(ComputeNNZ)),
+    ("max", Max, vectorUDT, Seq(ComputeMax, ComputeNNZ)),
+    ("min", Min, vectorUDT, Seq(ComputeMin, ComputeNNZ)),
+    ("normL2", NormL2, vectorUDT, Seq(ComputeM2)),
+    ("normL1", NormL1, vectorUDT, Seq(ComputeL1))
   )
 
   /**
@@ -527,27 +571,28 @@
       weightExpr: Expression,
       mutableAggBufferOffset: Int,
       inputAggBufferOffset: Int)
-    extends TypedImperativeAggregate[SummarizerBuffer] {
+    extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes {
 
-    override def eval(state: SummarizerBuffer): InternalRow = {
+    override def eval(state: SummarizerBuffer): Any = {
       val metrics = requestedMetrics.map {
-        case Mean => UnsafeArrayData.fromPrimitiveArray(state.mean.toArray)
-        case Variance => UnsafeArrayData.fromPrimitiveArray(state.variance.toArray)
+        case Mean => vectorUDT.serialize(state.mean)
+        case Variance => vectorUDT.serialize(state.variance)
         case Count => state.count
-        case NumNonZeros => UnsafeArrayData.fromPrimitiveArray(
-          state.numNonzeros.toArray.map(_.toLong))
-        case Max => UnsafeArrayData.fromPrimitiveArray(state.max.toArray)
-        case Min => UnsafeArrayData.fromPrimitiveArray(state.min.toArray)
-        case NormL2 => UnsafeArrayData.fromPrimitiveArray(state.normL2.toArray)
-        case NormL1 => UnsafeArrayData.fromPrimitiveArray(state.normL1.toArray)
+        case NumNonZeros => vectorUDT.serialize(state.numNonzeros)
+        case Max => vectorUDT.serialize(state.max)
+        case Min => vectorUDT.serialize(state.min)
+        case NormL2 => vectorUDT.serialize(state.normL2)
+        case NormL1 => vectorUDT.serialize(state.normL1)
       }
       InternalRow.apply(metrics: _*)
     }
 
+    override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil
+
     override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil
 
     override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
-      val features = udt.deserialize(featuresExpr.eval(row))
+      val features = vectorUDT.deserialize(featuresExpr.eval(row))
       val weight = weightExpr.eval(row).asInstanceOf[Double]
       state.add(features, weight)
       state
@@ -591,7 +636,4 @@ private[ml] object SummaryBuilderImpl extends Logging {
     override def prettyName: String = "aggregate_metrics"
 
   }
-
-  private[this] val udt = new VectorUDT
-
 }
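
One nit worth flagging in the new scaladoc above: the multi-metric example appears to be missing the closing parenthesis of `dataframe.select(`. A balanced version of that example, keeping the doc's own `dataframe = ...` placeholder, would read:

```scala
import org.apache.spark.ml.linalg._
import org.apache.spark.sql.Row

val dataframe = ... // Some dataframe containing a feature column and a weight column
val multiStatsDF = dataframe.select(
  Summarizer.metrics("min", "max", "count").summary($"features", $"weight"))
val Row(Row(minVec, maxVec, count)) = multiStatsDF.first()
```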
mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+
+import org.apache.spark.SharedSparkSession;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.Dataset;
+import static org.apache.spark.sql.functions.col;
+import org.apache.spark.ml.feature.LabeledPoint;
+import org.apache.spark.ml.linalg.Vector;
+import org.apache.spark.ml.linalg.Vectors;
+
+public class JavaSummarizerSuite extends SharedSparkSession {
+
+  private transient Dataset<Row> dataset;
+
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
+    List<LabeledPoint> points = new ArrayList<LabeledPoint>();
+    points.add(new LabeledPoint(0.0, Vectors.dense(1.0, 2.0)));
+    points.add(new LabeledPoint(0.0, Vectors.dense(3.0, 4.0)));
+
+    dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
+  }
+
+  @Test
+  public void testSummarizer() {
+    dataset.select(col("features"));
+    Row result = dataset
+      .select(Summarizer.metrics("mean", "max", "count").summary(col("features")))
+      .first().getStruct(0);
+    Vector meanVec = result.getAs("mean");
+    Vector maxVec = result.getAs("max");
+    long count = result.getAs("count");
+
+    assertEquals(2L, count);
+    assertArrayEquals(new double[]{2.0, 3.0}, meanVec.toArray(), 0.0);
+    assertArrayEquals(new double[]{3.0, 4.0}, maxVec.toArray(), 0.0);
+  }
+}
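
The suite above only exercises the default weight of 1.0. A sketch of the corresponding weighted spot-check, written in Scala under the same assumptions as the earlier sketch (a SparkSession named `spark`; the data and expected values are worked out here, not taken from the patch):

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col

import spark.implicits._

// Same two points as the Java suite, but the second row carries weight 2.0.
val weighted = Seq(
  (Vectors.dense(1.0, 2.0), 1.0),
  (Vectors.dense(3.0, 4.0), 2.0)
).toDF("features", "weight")

val Row(meanVec: Vector) =
  weighted.select(Summarizer.mean(col("features"), col("weight"))).first()

// Weighted mean: ((1*1 + 3*2)/3, (2*1 + 4*2)/3) = (7/3, 10/3).
assert(math.abs(meanVec(0) - 7.0 / 3) < 1e-9)
assert(math.abs(meanVec(1) - 10.0 / 3) < 1e-9)
```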
