@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
@@ -41,7 +41,7 @@ sealed abstract class SummaryBuilder {
   /**
    * Returns an aggregate object that contains the summary of the column with the requested metrics.
    * @param featuresCol a column that contains features Vector object.
-   * @param weightCol a column that contains weight value.
+   * @param weightCol a column that contains weight value. Default weight is 1.0.
    * @return an aggregate column that contains the statistics. The exact content of this
    * structure is determined during the creation of the builder.
    */
@@ -50,6 +50,7 @@ sealed abstract class SummaryBuilder {

   @Since("2.3.0")
   def summary(featuresCol: Column): Column = summary(featuresCol, lit(1.0))
+
 }

 /**
@@ -60,15 +61,18 @@ sealed abstract class SummaryBuilder {
  * This class lets users pick the statistics they would like to extract for a given column. Here is
  * an example in Scala:
  * {{{
- *   val dataframe = ... // Some dataframe containing a feature column
- *   val allStats = dataframe.select(Summarizer.metrics("min", "max").summary($"features"))
- *   val Row(Row(min_, max_)) = allStats.first()
+ *   import org.apache.spark.ml.linalg._
+ *   import org.apache.spark.sql.Row
+ *   val dataframe = ... // Some dataframe containing a feature column and a weight column
+ *   val multiStatsDF = dataframe.select(
+ *       Summarizer.metrics("min", "max", "count").summary($"features", $"weight"))
+ *   val Row(Row(minVec, maxVec, count)) = multiStatsDF.first()
  * }}}
  *
  * If one wants to get a single metric, shortcuts are also available:
  * {{{
  *   val meanDF = dataframe.select(Summarizer.mean($"features"))
- *   val Row(mean_) = meanDF.first()
+ *   val Row(meanVec) = meanDF.first()
  * }}}
  *
  * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD
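For reference, a complete, compilable version of the updated scaladoc example might look like the sketch below. The session setup, data, and column names are made up for illustration; it assumes Spark 2.3+ with spark-mllib on the classpath.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.{Row, SparkSession}

object SummarizerDocExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("summarizer-doc").getOrCreate()
    import spark.implicits._

    // A toy dataframe with a features column and a weight column.
    val dataframe = Seq(
      (Vectors.dense(1.0, 2.0), 1.0),
      (Vectors.dense(3.0, 4.0), 2.0)
    ).toDF("features", "weight")

    // metrics(...).summary(...) yields a single struct column, hence the
    // nested Row pattern when the first row is read back on the driver.
    val multiStatsDF = dataframe.select(
      Summarizer.metrics("min", "max", "count").summary($"features", $"weight"))
    val Row(Row(minVec: Vector, maxVec: Vector, count: Long)) = multiStatsDF.first()
    println(s"min=$minVec max=$maxVec count=$count")

    spark.stop()
  }
}
```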
@@ -94,46 +98,87 @@ object Summarizer extends Logging {
    * - min: the minimum for each coefficient.
    * - normL2: the Euclidean norm for each coefficient.
    * - normL1: the L1 norm of each coefficient (sum of the absolute values).
-   * @param firstMetric the metric being provided
-   * @param metrics additional metrics that can be provided.
+   * @param metrics metrics that can be provided.
    * @return a builder.
    * @throws IllegalArgumentException if one of the metric names is not understood.
    *
    * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD
    * interface.
    */
   @Since("2.3.0")
-  def metrics(firstMetric: String, metrics: String*): SummaryBuilder = {
-    val (typedMetrics, computeMetrics) = getRelevantMetrics(Seq(firstMetric) ++ metrics)
+  @scala.annotation.varargs
+  def metrics(metrics: String*): SummaryBuilder = {
+    require(metrics.size >= 1, "Should include at least one metric")
+    val (typedMetrics, computeMetrics) = getRelevantMetrics(metrics)
     new SummaryBuilderImpl(typedMetrics, computeMetrics)
   }

   @Since("2.3.0")
-  def mean(col: Column): Column = getSingleMetric(col, "mean")
+  def mean(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "mean")
+  }
+
+  @Since("2.3.0")
+  def mean(col: Column): Column = mean(col, lit(1.0))
+
+  @Since("2.3.0")
+  def variance(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "variance")
+  }
+
+  @Since("2.3.0")
+  def variance(col: Column): Column = variance(col, lit(1.0))
+
+  @Since("2.3.0")
+  def count(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "count")
+  }
+
+  @Since("2.3.0")
+  def count(col: Column): Column = count(col, lit(1.0))

   @Since("2.3.0")
-  def variance(col: Column): Column = getSingleMetric(col, "variance")
+  def numNonZeros(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "numNonZeros")
+  }
+
+  @Since("2.3.0")
+  def numNonZeros(col: Column): Column = numNonZeros(col, lit(1.0))
+
+  @Since("2.3.0")
+  def max(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "max")
+  }
+
+  @Since("2.3.0")
+  def max(col: Column): Column = max(col, lit(1.0))

   @Since("2.3.0")
-  def count(col: Column): Column = getSingleMetric(col, "count")
+  def min(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "min")
+  }

   @Since("2.3.0")
-  def numNonZeros(col: Column): Column = getSingleMetric(col, "numNonZeros")
+  def min(col: Column): Column = min(col, lit(1.0))

   @Since("2.3.0")
-  def max(col: Column): Column = getSingleMetric(col, "max")
+  def normL1(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "normL1")
+  }

   @Since("2.3.0")
-  def min(col: Column): Column = getSingleMetric(col, "min")
+  def normL1(col: Column): Column = normL1(col, lit(1.0))

   @Since("2.3.0")
-  def normL1(col: Column): Column = getSingleMetric(col, "normL1")
+  def normL2(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "normL2")
+  }

   @Since("2.3.0")
-  def normL2(col: Column): Column = getSingleMetric(col, "normL2")
+  def normL2(col: Column): Column = normL2(col, lit(1.0))

-  private def getSingleMetric(col: Column, metric: String): Column = {
-    val c1 = metrics(metric).summary(col)
+  private def getSingleMetric(col: Column, weightCol: Column, metric: String): Column = {
+    val c1 = metrics(metric).summary(col, weightCol)
     c1.getField(metric).as(s"$metric($col)")
   }
 }
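Building on the hypothetical `dataframe` from the previous sketch (a features column plus a weight column), the new shortcut overloads pair up as below; each unweighted form simply delegates to the weighted one with a constant weight of `lit(1.0)`.

```scala
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.functions.lit

// Weighted form: every single-metric shortcut now accepts a weight column.
dataframe.select(
  Summarizer.mean($"features", $"weight"),
  Summarizer.variance($"features", $"weight"),
  Summarizer.normL1($"features", $"weight")).show()

// Unweighted form, equivalent to passing lit(1.0) explicitly.
dataframe.select(Summarizer.mean($"features")).show()
dataframe.select(Summarizer.mean($"features", lit(1.0))).show()
```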
@@ -187,8 +232,7 @@ private[ml] object SummaryBuilderImpl extends Logging {
     StructType(fields)
   }

-  private val arrayDType = ArrayType(DoubleType, containsNull = false)
-  private val arrayLType = ArrayType(LongType, containsNull = false)
+  private val vectorUDT = new VectorUDT

   /**
    * All the metrics that can be currently computed by Spark for vectors.
@@ -197,14 +241,14 @@ private[ml] object SummaryBuilderImpl extends Logging {
    * metrics that need to be computed internally to get the final result.
    */
   private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq(
-    ("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)),
-    ("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
+    ("mean", Mean, vectorUDT, Seq(ComputeMean, ComputeWeightSum)),
+    ("variance", Variance, vectorUDT, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
     ("count", Count, LongType, Seq()),
-    ("numNonZeros", NumNonZeros, arrayLType, Seq(ComputeNNZ)),
-    ("max", Max, arrayDType, Seq(ComputeMax, ComputeNNZ)),
-    ("min", Min, arrayDType, Seq(ComputeMin, ComputeNNZ)),
-    ("normL2", NormL2, arrayDType, Seq(ComputeM2)),
-    ("normL1", NormL1, arrayDType, Seq(ComputeL1))
+    ("numNonZeros", NumNonZeros, vectorUDT, Seq(ComputeNNZ)),
+    ("max", Max, vectorUDT, Seq(ComputeMax, ComputeNNZ)),
+    ("min", Min, vectorUDT, Seq(ComputeMin, ComputeNNZ)),
+    ("normL2", NormL2, vectorUDT, Seq(ComputeM2)),
+    ("normL1", NormL1, vectorUDT, Seq(ComputeL1))
   )

   /**
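Because the declared result type of the vector-valued metrics switches from `ArrayType(DoubleType)` (or `ArrayType(LongType)` for `numNonZeros`) to `VectorUDT`, callers now read these fields back as `ml.linalg.Vector`; only `count` stays a plain `Long`. A hedged sketch of the caller-visible difference, reusing the hypothetical `dataframe` from the earlier sketch:

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.stat.Summarizer

val row = dataframe.select(Summarizer.max($"features")).first()
// After this change the metric deserializes to a Vector ...
val maxVec: Vector = row.getAs[Vector](0)
// ... where it previously surfaced as an array of doubles, e.g.
// val maxArr: Seq[Double] = row.getAs[Seq[Double]](0)
```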
@@ -527,27 +571,28 @@ private[ml] object SummaryBuilderImpl extends Logging {
       weightExpr: Expression,
       mutableAggBufferOffset: Int,
       inputAggBufferOffset: Int)
-    extends TypedImperativeAggregate[SummarizerBuffer] {
+    extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes {

-    override def eval(state: SummarizerBuffer): InternalRow = {
+    override def eval(state: SummarizerBuffer): Any = {
       val metrics = requestedMetrics.map {
-        case Mean => UnsafeArrayData.fromPrimitiveArray(state.mean.toArray)
-        case Variance => UnsafeArrayData.fromPrimitiveArray(state.variance.toArray)
+        case Mean => vectorUDT.serialize(state.mean)
+        case Variance => vectorUDT.serialize(state.variance)
         case Count => state.count
-        case NumNonZeros => UnsafeArrayData.fromPrimitiveArray(
-          state.numNonzeros.toArray.map(_.toLong))
-        case Max => UnsafeArrayData.fromPrimitiveArray(state.max.toArray)
-        case Min => UnsafeArrayData.fromPrimitiveArray(state.min.toArray)
-        case NormL2 => UnsafeArrayData.fromPrimitiveArray(state.normL2.toArray)
-        case NormL1 => UnsafeArrayData.fromPrimitiveArray(state.normL1.toArray)
+        case NumNonZeros => vectorUDT.serialize(state.numNonzeros)
+        case Max => vectorUDT.serialize(state.max)
+        case Min => vectorUDT.serialize(state.min)
+        case NormL2 => vectorUDT.serialize(state.normL2)
+        case NormL1 => vectorUDT.serialize(state.normL1)
       }
       InternalRow.apply(metrics: _*)
     }

+    override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil
+
     override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil

     override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
-      val features = udt.deserialize(featuresExpr.eval(row))
+      val features = vectorUDT.deserialize(featuresExpr.eval(row))
       val weight = weightExpr.eval(row).asInstanceOf[Double]
       state.add(features, weight)
       state
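Mixing in `ImplicitCastInputTypes` and declaring `inputTypes` lets the analyzer insert casts on the aggregate's children, so the `asInstanceOf[Double]` in `update` stays safe even when the caller supplies a non-double weight column. A hedged sketch of the behavior this should enable, assuming the session and implicits from the first sketch are in scope:

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.stat.Summarizer

// The weight column is IntegerType here; with inputTypes = vector :: double,
// the analyzer is expected to cast it to DoubleType automatically.
val intWeighted = Seq(
  (Vectors.dense(1.0, 2.0), 1),
  (Vectors.dense(3.0, 4.0), 2)
).toDF("features", "weight")

intWeighted.select(Summarizer.mean($"features", $"weight")).show()
```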
@@ -591,7 +636,4 @@ private[ml] object SummaryBuilderImpl extends Logging {
     override def prettyName: String = "aggregate_metrics"

   }
-
-  private[this] val udt = new VectorUDT
-
 }