Skip to content

Commit d18c899

Browse files
committed
[SPARK-50929][ML][PYTHON][CONNECT] Support LDA on Connect
### What changes were proposed in this pull request?
Support `LDA` on Connect

### Why are the changes needed?
feature parity

### Does this PR introduce _any_ user-facing change?
yes

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49679 from zhengruifeng/ml_connect_lda.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
(cherry picked from commit b6b00e8)
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent c4d3371 commit d18c899

File tree

6 files changed

+159
-1
lines changed

6 files changed

+159
-1
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ org.apache.spark.ml.regression.GBTRegressor
3737
org.apache.spark.ml.clustering.KMeans
3838
org.apache.spark.ml.clustering.BisectingKMeans
3939
org.apache.spark.ml.clustering.GaussianMixture
40+
org.apache.spark.ml.clustering.LDA
4041

4142
# recommendation
4243
org.apache.spark.ml.recommendation.ALS

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ org.apache.spark.ml.regression.GBTRegressionModel
5353
org.apache.spark.ml.clustering.KMeansModel
5454
org.apache.spark.ml.clustering.BisectingKMeansModel
5555
org.apache.spark.ml.clustering.GaussianMixtureModel
56+
org.apache.spark.ml.clustering.DistributedLDAModel
57+
org.apache.spark.ml.clustering.LocalLDAModel
5658

5759
# recommendation
5860
org.apache.spark.ml.recommendation.ALSModel

mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,8 @@ class LocalLDAModel private[ml] (
617617
sparkSession: SparkSession)
618618
extends LDAModel(uid, vocabSize, sparkSession) {
619619

620+
private[ml] def this() = this(Identifiable.randomUID("lda"), -1, null, null)
621+
620622
oldLocalModel.setSeed(getSeed)
621623

622624
@Since("1.6.0")
@@ -713,6 +715,8 @@ class DistributedLDAModel private[ml] (
713715
private var oldLocalModelOption: Option[OldLocalLDAModel])
714716
extends LDAModel(uid, vocabSize, sparkSession) {
715717

718+
private[ml] def this() = this(Identifiable.randomUID("lda"), -1, null, null, None)
719+
716720
override private[clustering] def oldLocalModel: OldLocalLDAModel = {
717721
if (oldLocalModelOption.isEmpty) {
718722
oldLocalModelOption = Some(oldDistributedModel.toLocal)

python/pyspark/ml/clustering.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,6 +1511,7 @@ def logPerplexity(self, dataset: DataFrame) -> float:
15111511
return self._call_java("logPerplexity", dataset)
15121512

15131513
@since("2.0.0")
1514+
@try_remote_attribute_relation
15141515
def describeTopics(self, maxTermsPerTopic: int = 10) -> DataFrame:
15151516
"""
15161517
Return the topics described by their top-weighted terms.

python/pyspark/ml/tests/test_clustering.py

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import numpy as np
2222

23-
from pyspark.ml.linalg import Vectors
23+
from pyspark.ml.linalg import Vectors, SparseVector
2424
from pyspark.sql import SparkSession
2525
from pyspark.ml.clustering import (
2626
KMeans,
@@ -32,6 +32,10 @@
3232
GaussianMixture,
3333
GaussianMixtureModel,
3434
GaussianMixtureSummary,
35+
LDA,
36+
LDAModel,
37+
LocalLDAModel,
38+
DistributedLDAModel,
3539
)
3640

3741

@@ -264,6 +268,139 @@ def test_gaussian_mixture(self):
264268
model2 = GaussianMixtureModel.load(d)
265269
self.assertEqual(str(model), str(model2))
266270

271+
def test_local_lda(self):
272+
spark = self.spark
273+
df = (
274+
spark.createDataFrame(
275+
[
276+
[1, Vectors.dense([0.0, 1.0])],
277+
[2, SparseVector(2, {0: 1.0})],
278+
],
279+
["id", "features"],
280+
)
281+
.coalesce(1)
282+
.sortWithinPartitions("id")
283+
)
284+
285+
lda = LDA(k=2, optimizer="online", seed=1)
286+
lda.setMaxIter(1)
287+
self.assertEqual(lda.getK(), 2)
288+
self.assertEqual(lda.getOptimizer(), "online")
289+
self.assertEqual(lda.getMaxIter(), 1)
290+
self.assertEqual(lda.getSeed(), 1)
291+
292+
model = lda.fit(df)
293+
self.assertEqual(lda.uid, model.uid)
294+
self.assertIsInstance(model, LDAModel)
295+
self.assertIsInstance(model, LocalLDAModel)
296+
self.assertNotIsInstance(model, DistributedLDAModel)
297+
self.assertFalse(model.isDistributed())
298+
299+
dc = model.estimatedDocConcentration()
300+
self.assertTrue(np.allclose(dc.toArray(), [0.5, 0.5], atol=1e-4), dc)
301+
topics = model.topicsMatrix()
302+
self.assertTrue(
303+
np.allclose(
304+
topics.toArray(), [[1.20296728, 1.15740442], [0.99357675, 1.02993164]], atol=1e-4
305+
),
306+
topics,
307+
)
308+
309+
ll = model.logLikelihood(df)
310+
self.assertTrue(np.allclose(ll, -3.2125122434040088, atol=1e-4), ll)
311+
lp = model.logPerplexity(df)
312+
self.assertTrue(np.allclose(lp, 1.6062561217020044, atol=1e-4), lp)
313+
dt = model.describeTopics()
314+
self.assertEqual(dt.columns, ["topic", "termIndices", "termWeights"])
315+
self.assertEqual(dt.count(), 2)
316+
317+
# LocalLDAModel specific methods
318+
self.assertEqual(model.vocabSize(), 2)
319+
320+
output = model.transform(df)
321+
expected_cols = ["id", "features", "topicDistribution"]
322+
self.assertEqual(output.columns, expected_cols)
323+
self.assertEqual(output.count(), 2)
324+
325+
# save & load
326+
with tempfile.TemporaryDirectory(prefix="local_lda") as d:
327+
lda.write().overwrite().save(d)
328+
lda2 = LDA.load(d)
329+
self.assertEqual(str(lda), str(lda2))
330+
331+
model.write().overwrite().save(d)
332+
model2 = LocalLDAModel.load(d)
333+
self.assertEqual(str(model), str(model2))
334+
335+
def test_distributed_lda(self):
336+
spark = self.spark
337+
df = (
338+
spark.createDataFrame(
339+
[
340+
[1, Vectors.dense([0.0, 1.0])],
341+
[2, SparseVector(2, {0: 1.0})],
342+
],
343+
["id", "features"],
344+
)
345+
.coalesce(1)
346+
.sortWithinPartitions("id")
347+
)
348+
349+
lda = LDA(k=2, optimizer="em", seed=1)
350+
lda.setMaxIter(1)
351+
352+
self.assertEqual(lda.getK(), 2)
353+
self.assertEqual(lda.getOptimizer(), "em")
354+
self.assertEqual(lda.getMaxIter(), 1)
355+
self.assertEqual(lda.getSeed(), 1)
356+
357+
model = lda.fit(df)
358+
self.assertEqual(lda.uid, model.uid)
359+
self.assertIsInstance(model, LDAModel)
360+
self.assertNotIsInstance(model, LocalLDAModel)
361+
self.assertIsInstance(model, DistributedLDAModel)
362+
363+
dc = model.estimatedDocConcentration()
364+
self.assertTrue(np.allclose(dc.toArray(), [26.0, 26.0], atol=1e-4), dc)
365+
topics = model.topicsMatrix()
366+
self.assertTrue(
367+
np.allclose(
368+
topics.toArray(), [[0.39149926, 0.60850074], [0.60991237, 0.39008763]], atol=1e-4
369+
),
370+
topics,
371+
)
372+
373+
ll = model.logLikelihood(df)
374+
self.assertTrue(np.allclose(ll, -3.719138517085772, atol=1e-4), ll)
375+
lp = model.logPerplexity(df)
376+
self.assertTrue(np.allclose(lp, 1.859569258542886, atol=1e-4), lp)
377+
378+
dt = model.describeTopics()
379+
self.assertEqual(dt.columns, ["topic", "termIndices", "termWeights"])
380+
self.assertEqual(dt.count(), 2)
381+
382+
# DistributedLDAModel specific methods
383+
ll = model.trainingLogLikelihood()
384+
self.assertTrue(np.allclose(ll, -1.3847360462201639, atol=1e-4), ll)
385+
lp = model.logPrior()
386+
self.assertTrue(np.allclose(lp, -69.59963186898915, atol=1e-4), lp)
387+
model.getCheckpointFiles()
388+
389+
output = model.transform(df)
390+
expected_cols = ["id", "features", "topicDistribution"]
391+
self.assertEqual(output.columns, expected_cols)
392+
self.assertEqual(output.count(), 2)
393+
394+
# save & load
395+
with tempfile.TemporaryDirectory(prefix="distributed_lda") as d:
396+
lda.write().overwrite().save(d)
397+
lda2 = LDA.load(d)
398+
self.assertEqual(str(lda), str(lda2))
399+
400+
model.write().overwrite().save(d)
401+
model2 = DistributedLDAModel.load(d)
402+
self.assertEqual(str(model), str(model2))
403+
267404

268405
class ClusteringTests(ClusteringTestsMixin, unittest.TestCase):
269406
def setUp(self) -> None:

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,19 @@ private[ml] object MLUtils {
594594
classOf[GaussianMixtureModel],
595595
Set("predict", "numFeatures", "weights", "gaussians", "predictProbability", "gaussiansDF")),
596596
(classOf[GaussianMixtureSummary], Set("probability", "probabilityCol", "logLikelihood")),
597+
(
598+
classOf[LDAModel],
599+
Set(
600+
"estimatedDocConcentration",
601+
"topicsMatrix",
602+
"isDistributed",
603+
"logLikelihood",
604+
"logPerplexity",
605+
"describeTopics")),
606+
(classOf[LocalLDAModel], Set("vocabSize")),
607+
(
608+
classOf[DistributedLDAModel],
609+
Set("trainingLogLikelihood", "logPrior", "getCheckpointFiles")),
597610

598611
// Recommendation Models
599612
(

0 commit comments

Comments (0)