Skip to content

Commit 1ca3c50

Browse files
WeichenXu123 authored and jkbradley committed
[SPARK-21741][ML][PYSPARK] Python API for DataFrame-based multivariate summarizer
## What changes were proposed in this pull request? Python API for DataFrame-based multivariate summarizer. ## How was this patch tested? doctest added. Author: WeichenXu <[email protected]> Closes apache#20695 from WeichenXu123/py_summarizer.
1 parent f39e82c commit 1ca3c50

File tree

1 file changed

+192
-1
lines changed

1 file changed

+192
-1
lines changed

python/pyspark/ml/stat.py

Lines changed: 192 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
from pyspark import since, SparkContext
2121
from pyspark.ml.common import _java2py, _py2java
22-
from pyspark.ml.wrapper import _jvm
22+
from pyspark.ml.wrapper import JavaWrapper, _jvm
23+
from pyspark.sql.column import Column, _to_seq
24+
from pyspark.sql.functions import lit
2325

2426

2527
class ChiSquareTest(object):
@@ -195,6 +197,195 @@ def test(dataset, sampleCol, distName, *params):
195197
_jvm().PythonUtils.toSeq(params)))
196198

197199

200+
class Summarizer(object):
    """
    .. note:: Experimental

    Tools for vectorized statistics on MLlib Vectors.

    The methods in this package provide various statistics for Vectors contained inside
    DataFrames. This class lets users pick the statistics they would like to extract for
    a given column.

    >>> from pyspark.ml.stat import Summarizer
    >>> from pyspark.sql import Row
    >>> from pyspark.ml.linalg import Vectors
    >>> summarizer = Summarizer.metrics("mean", "count")
    >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
    ...                      Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
    >>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
    +-----------------------------------+
    |aggregate_metrics(features, weight)|
    +-----------------------------------+
    |[[1.0,1.0,1.0], 1]                 |
    +-----------------------------------+
    <BLANKLINE>
    >>> df.select(summarizer.summary(df.features)).show(truncate=False)
    +--------------------------------+
    |aggregate_metrics(features, 1.0)|
    +--------------------------------+
    |[[1.0,1.5,2.0], 2]              |
    +--------------------------------+
    <BLANKLINE>
    >>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False)
    +--------------+
    |mean(features)|
    +--------------+
    |[1.0,1.0,1.0] |
    +--------------+
    <BLANKLINE>
    >>> df.select(Summarizer.mean(df.features)).show(truncate=False)
    +--------------+
    |mean(features)|
    +--------------+
    |[1.0,1.5,2.0] |
    +--------------+
    <BLANKLINE>

    .. versionadded:: 2.4.0

    """
    @staticmethod
    @since("2.4.0")
    def mean(col, weightCol=None):
        """
        return a column of mean summary
        """
        return Summarizer._get_single_metric(col, weightCol, "mean")

    @staticmethod
    @since("2.4.0")
    def variance(col, weightCol=None):
        """
        return a column of variance summary
        """
        return Summarizer._get_single_metric(col, weightCol, "variance")

    @staticmethod
    @since("2.4.0")
    def count(col, weightCol=None):
        """
        return a column of count summary
        """
        return Summarizer._get_single_metric(col, weightCol, "count")

    @staticmethod
    @since("2.4.0")
    def numNonZeros(col, weightCol=None):
        """
        return a column of numNonZero summary
        """
        return Summarizer._get_single_metric(col, weightCol, "numNonZeros")

    @staticmethod
    @since("2.4.0")
    def max(col, weightCol=None):
        """
        return a column of max summary
        """
        return Summarizer._get_single_metric(col, weightCol, "max")

    @staticmethod
    @since("2.4.0")
    def min(col, weightCol=None):
        """
        return a column of min summary
        """
        return Summarizer._get_single_metric(col, weightCol, "min")

    @staticmethod
    @since("2.4.0")
    def normL1(col, weightCol=None):
        """
        return a column of normL1 summary
        """
        return Summarizer._get_single_metric(col, weightCol, "normL1")

    @staticmethod
    @since("2.4.0")
    def normL2(col, weightCol=None):
        """
        return a column of normL2 summary
        """
        return Summarizer._get_single_metric(col, weightCol, "normL2")

    @staticmethod
    def _check_param(featuresCol, weightCol):
        """
        Validate the feature/weight column pair, defaulting the weight to a
        literal 1.0 column when none is given. Returns the (possibly
        defaulted) pair; raises TypeError for non-Column arguments.
        """
        if weightCol is None:
            # Unweighted case: every row contributes with weight 1.0.
            weightCol = lit(1.0)
        if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column):
            # Fixed message: the parameter is named `featuresCol`, not `featureCol`.
            raise TypeError("featuresCol and weightCol should be a Column")
        return featuresCol, weightCol

    @staticmethod
    def _get_single_metric(col, weightCol, metric):
        """
        Build a Column wrapping the JVM-side single-metric aggregate
        (e.g. org.apache.spark.ml.stat.Summarizer.mean).
        """
        col, weightCol = Summarizer._check_param(col, weightCol)
        return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric,
                                                col._jc, weightCol._jc))

    @staticmethod
    @since("2.4.0")
    def metrics(*metrics):
        """
        Given a list of metrics, provides a builder that in turn computes metrics from a column.

        See the documentation of [[Summarizer]] for an example.

        The following metrics are accepted (case sensitive):
         - mean: a vector that contains the coefficient-wise mean.
         - variance: a vector that contains the coefficient-wise variance.
         - count: the count of all vectors seen.
         - numNonZeros: a vector with the number of non-zeros for each coefficients
         - max: the maximum for each coefficient.
         - min: the minimum for each coefficient.
         - normL2: the Euclidean norm for each coefficient.
         - normL1: the L1 norm of each coefficient (sum of the absolute values).

        :param metrics:
         metrics that can be provided.
        :return:
         an object of :py:class:`pyspark.ml.stat.SummaryBuilder`

        Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD
        interface.
        """
        sc = SparkContext._active_spark_context
        js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics",
                                       _to_seq(sc, metrics))
        return SummaryBuilder(js)
356+
class SummaryBuilder(JavaWrapper):
    """
    .. note:: Experimental

    A builder object that provides summary statistics about a given column.

    Users should not directly create such builders, but instead use one of the methods in
    :py:class:`pyspark.ml.stat.Summarizer`

    .. versionadded:: 2.4.0

    """
    def __init__(self, jSummaryBuilder):
        # Wrap the JVM-side SummaryBuilder produced by Summarizer.metrics().
        super(SummaryBuilder, self).__init__(jSummaryBuilder)

    @since("2.4.0")
    def summary(self, featuresCol, weightCol=None):
        """
        Returns an aggregate object that contains the summary of the column with the requested
        metrics.

        :param featuresCol:
         a column that contains features Vector object.
        :param weightCol:
         a column that contains weight value. Default weight is 1.0.
        :return:
         an aggregate column that contains the statistics. The exact content of this
         structure is determined during the creation of the builder.
        """
        # Validation and weight defaulting are shared with Summarizer's static helpers.
        featuresCol, weightCol = Summarizer._check_param(featuresCol, weightCol)
        jcol = self._java_obj.summary(featuresCol._jc, weightCol._jc)
        return Column(jcol)
198389
if __name__ == "__main__":
199390
import doctest
200391
import pyspark.ml.stat

0 commit comments

Comments
 (0)