Skip to content

Commit a32c92c

Browse files
henrydavidge authored and srowen committed
[SPARK-28140][MLLIB][PYTHON] Accept DataFrames in RowMatrix and IndexedRowMatrix constructors
## What changes were proposed in this pull request?

The PySpark `RowMatrix` and `IndexedRowMatrix` constructors now accept a `DataFrame` in addition to an RDD. In both cases, the input `DataFrame` schema must contain only the information that's required for the matrix object: a single vector column in the case of `RowMatrix`, and a long column followed by a vector column for `IndexedRowMatrix`.

## How was this patch tested?

Unit tests that verify:
- `RowMatrix` and `IndexedRowMatrix` can be created from `DataFrame`s
- If the schema does not match expectations, we throw an `IllegalArgumentException`

Please review https://spark.apache.org/contributing.html before opening a pull request.

Closes apache#24953 from henrydavidge/row-matrix-df.

Authored-by: Henry D <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
1 parent 019efaa commit a32c92c

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTree
5353
import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
5454
import org.apache.spark.rdd.RDD
5555
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
56+
import org.apache.spark.sql.types.LongType
5657
import org.apache.spark.storage.StorageLevel
5758
import org.apache.spark.util.Utils
5859

@@ -1142,12 +1143,21 @@ private[python] class PythonMLLibAPI extends Serializable {
11421143
new RowMatrix(rows.rdd, numRows, numCols)
11431144
}
11441145

1146+
def createRowMatrix(df: DataFrame, numRows: Long, numCols: Int): RowMatrix = {
1147+
require(df.schema.length == 1 && df.schema.head.dataType.getClass == classOf[VectorUDT],
1148+
"DataFrame must have a single vector type column")
1149+
new RowMatrix(df.rdd.map { case Row(vector: Vector) => vector }, numRows, numCols)
1150+
}
1151+
11451152
/**
11461153
* Wrapper around IndexedRowMatrix constructor.
11471154
*/
11481155
def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = {
11491156
// We use DataFrames for serialization of IndexedRows from Python,
11501157
// so map each Row in the DataFrame back to an IndexedRow.
1158+
require(rows.schema.length == 2 && rows.schema.head.dataType == LongType &&
1159+
rows.schema(1).dataType.getClass == classOf[VectorUDT],
1160+
"DataFrame must consist of a long type index column and a vector type column")
11511161
val indexedRows = rows.rdd.map {
11521162
case Row(index: Long, vector: Vector) => IndexedRow(index, vector)
11531163
}

python/pyspark/mllib/linalg/distributed.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
3131
from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition
3232
from pyspark.mllib.stat import MultivariateStatisticalSummary
33+
from pyspark.sql import DataFrame
3334
from pyspark.storagelevel import StorageLevel
3435

3536

@@ -57,7 +58,8 @@ class RowMatrix(DistributedMatrix):
5758
Represents a row-oriented distributed Matrix with no meaningful
5859
row indices.
5960
60-
:param rows: An RDD of vectors.
61+
:param rows: An RDD or DataFrame of vectors. If a DataFrame is provided, it must have a single
62+
vector typed column.
6163
:param numRows: Number of rows in the matrix. A non-positive
6264
value means unknown, at which point the number
6365
of rows will be determined by the number of
@@ -73,7 +75,7 @@ def __init__(self, rows, numRows=0, numCols=0):
7375
7476
Create a wrapper over a Java RowMatrix.
7577
76-
Publicly, we require that `rows` be an RDD. However, for
78+
Publicly, we require that `rows` be an RDD or DataFrame. However, for
7779
internal usage, `rows` can also be a Java RowMatrix
7880
object, in which case we can wrap it directly. This
7981
assists in clean matrix conversions.
@@ -94,6 +96,8 @@ def __init__(self, rows, numRows=0, numCols=0):
9496
if isinstance(rows, RDD):
9597
rows = rows.map(_convert_to_vector)
9698
java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols))
99+
elif isinstance(rows, DataFrame):
100+
java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols))
97101
elif (isinstance(rows, JavaObject)
98102
and rows.getClass().getSimpleName() == "RowMatrix"):
99103
java_matrix = rows
@@ -461,7 +465,8 @@ class IndexedRowMatrix(DistributedMatrix):
461465
"""
462466
Represents a row-oriented distributed Matrix with indexed rows.
463467
464-
:param rows: An RDD of IndexedRows or (long, vector) tuples.
468+
:param rows: An RDD of IndexedRows or (long, vector) tuples or a DataFrame consisting of a
469+
long typed column of indices and a vector typed column.
465470
:param numRows: Number of rows in the matrix. A non-positive
466471
value means unknown, at which point the number
467472
of rows will be determined by the max row
@@ -477,7 +482,7 @@ def __init__(self, rows, numRows=0, numCols=0):
477482
478483
Create a wrapper over a Java IndexedRowMatrix.
479484
480-
Publicly, we require that `rows` be an RDD. However, for
485+
Publicly, we require that `rows` be an RDD or DataFrame. However, for
481486
internal usage, `rows` can also be a Java IndexedRowMatrix
482487
object, in which case we can wrap it directly. This
483488
assists in clean matrix conversions.
@@ -506,6 +511,8 @@ def __init__(self, rows, numRows=0, numCols=0):
506511
# IndexedRows on the Scala side.
507512
java_matrix = callMLlibFunc("createIndexedRowMatrix", rows.toDF(),
508513
long(numRows), int(numCols))
514+
elif isinstance(rows, DataFrame):
515+
java_matrix = callMLlibFunc("createIndexedRowMatrix", rows, long(numRows), int(numCols))
509516
elif (isinstance(rows, JavaObject)
510517
and rows.getClass().getSimpleName() == "IndexedRowMatrix"):
511518
java_matrix = rows

python/pyspark/mllib/tests/test_linalg.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,15 @@
2525
from pyspark.serializers import PickleSerializer
2626
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \
2727
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
28+
from pyspark.mllib.linalg.distributed import RowMatrix, IndexedRowMatrix
2829
from pyspark.mllib.regression import LabeledPoint
30+
from pyspark.sql import Row
2931
from pyspark.testing.mllibutils import MLlibTestCase
3032
from pyspark.testing.utils import have_scipy
3133

34+
if sys.version >= '3':
35+
long = int
36+
3237

3338
class VectorTests(MLlibTestCase):
3439

@@ -431,6 +436,24 @@ def test_infer_schema(self):
431436
else:
432437
raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
433438

439+
def test_row_matrix_from_dataframe(self):
440+
from pyspark.sql.utils import IllegalArgumentException
441+
df = self.spark.createDataFrame([Row(Vectors.dense(1))])
442+
row_matrix = RowMatrix(df)
443+
self.assertEqual(row_matrix.numRows(), 1)
444+
self.assertEqual(row_matrix.numCols(), 1)
445+
with self.assertRaises(IllegalArgumentException):
446+
RowMatrix(df.selectExpr("'monkey'"))
447+
448+
def test_indexed_row_matrix_from_dataframe(self):
449+
from pyspark.sql.utils import IllegalArgumentException
450+
df = self.spark.createDataFrame([Row(long(0), Vectors.dense(1))])
451+
matrix = IndexedRowMatrix(df)
452+
self.assertEqual(matrix.numRows(), 1)
453+
self.assertEqual(matrix.numCols(), 1)
454+
with self.assertRaises(IllegalArgumentException):
455+
IndexedRowMatrix(df.drop("_1"))
456+
434457

435458
class MatrixUDTTests(MLlibTestCase):
436459

0 commit comments

Comments
 (0)