Skip to content

Commit 5496e98

Browse files
committed
[SPARK-30109][ML] PCA use BLAS.gemv for sparse vectors
### What changes were proposed in this pull request? When PCA was first impled in [SPARK-5521](https://issues.apache.org/jira/browse/SPARK-5521), at that time Matrix.multiply(BLAS.gemv internally) did not support sparse vector. So worked around it by applying a sparse matrix multiplication. Since [SPARK-7681](https://issues.apache.org/jira/browse/SPARK-7681), BLAS.gemv supported sparse vector. So we can directly use Matrix.multiply now. ### Why are the changes needed? for simplity ### Does this PR introduce any user-facing change? No ### How was this patch tested? existing testsuites Closes apache#26745 from zhengruifeng/pca_mul. Authored-by: zhengruifeng <[email protected]> Signed-off-by: zhengruifeng <[email protected]>
1 parent 3dd3a62 commit 5496e98

File tree

2 files changed

+3
-28
lines changed

2 files changed

+3
-28
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -147,22 +147,8 @@ class PCAModel private[ml] (
147147
override def transform(dataset: Dataset[_]): DataFrame = {
148148
transformSchema(dataset.schema, logging = true)
149149

150-
val func = { vector: Vector =>
151-
vector match {
152-
case dv: DenseVector =>
153-
pc.transpose.multiply(dv)
154-
case SparseVector(size, indices, values) =>
155-
/* SparseVector -> single row SparseMatrix */
156-
val sm = Matrices.sparse(size, 1, Array(0, indices.length), indices, values).transpose
157-
val projection = sm.multiply(pc)
158-
Vectors.dense(projection.values)
159-
case _ =>
160-
throw new IllegalArgumentException("Unsupported vector format. Expected " +
161-
s"SparseVector or DenseVector. Instead got: ${vector.getClass}")
162-
}
163-
}
164-
165-
val transformer = udf(func)
150+
val transposed = pc.transpose
151+
val transformer = udf { vector: Vector => transposed.multiply(vector) }
166152
dataset.withColumn($(outputCol), transformer(col($(inputCol))))
167153
}
168154

mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,7 @@ class PCAModel private[spark] (
111111
*/
112112
@Since("1.4.0")
113113
override def transform(vector: Vector): Vector = {
114-
vector match {
115-
case dv: DenseVector =>
116-
pc.transpose.multiply(dv)
117-
case SparseVector(size, indices, values) =>
118-
/* SparseVector -> single row SparseMatrix */
119-
val sm = Matrices.sparse(size, 1, Array(0, indices.length), indices, values).transpose
120-
val projection = sm.multiply(pc)
121-
Vectors.dense(projection.values)
122-
case _ =>
123-
throw new IllegalArgumentException("Unsupported vector format. Expected " +
124-
s"SparseVector or DenseVector. Instead got: ${vector.getClass}")
125-
}
114+
pc.transpose.multiply(vector)
126115
}
127116
}
128117

0 commit comments

Comments
 (0)