File tree (3 files changed: +9 / −12 lines)
src/main/scala/com/massivedatascience/clusterer/ml (3 files changed: +9 / −12 lines)

@@ -336,13 +336,10 @@ class DPMeans(override val uid: String)
336336 )
337337
338338 // Check convergence (max center movement)
339- val maxMovement = centers
340- .zip(newCentersArray)
341- .map { case (old, neu) =>
342- kernel.divergence(Vectors .dense(old), Vectors .dense(neu))
343- }
344- .maxOption
345- .getOrElse(0.0 )
339+ val movements = centers.zip(newCentersArray).map { case (old, neu) =>
340+ kernel.divergence(Vectors .dense(old), Vectors .dense(neu))
341+ }
342+ val maxMovement = if (movements.isEmpty) 0.0 else movements.max
346343
347344 logInfo(f " Iteration $iter: max center movement = $maxMovement%.6f " )
348345
@@ -351,7 +348,7 @@ class DPMeans(override val uid: String)
351348 logInfo(s " Converged after $iter iterations (movement $maxMovement < tol $tolVal) " )
352349 }
353350
354- centers = ArrayBuffer .from (newCentersArray)
351+ centers = ArrayBuffer (newCentersArray : _* )
355352 }
356353
357354 assigned.unpersist()
@@ -75,7 +75,7 @@ class ConstraintSet(constraints: Seq[Constraint]) extends Serializable {
7575 builder.getOrElseUpdate(ml.i, mutable.Set .empty) += ml.j
7676 builder.getOrElseUpdate(ml.j, mutable.Set .empty) += ml.i
7777 }
78- builder.view.mapValues(_ .toSet) .toMap
78+ builder.map { case (k, v) => k -> v .toSet } .toMap
7979 }
8080
8181 private val cannotLinkIndex : Map [Long , Set [Long ]] = {
@@ -84,7 +84,7 @@ class ConstraintSet(constraints: Seq[Constraint]) extends Serializable {
8484 builder.getOrElseUpdate(cl.i, mutable.Set .empty) += cl.j
8585 builder.getOrElseUpdate(cl.j, mutable.Set .empty) += cl.i
8686 }
87- builder.view.mapValues(_ .toSet) .toMap
87+ builder.map { case (k, v) => k -> v .toSet } .toMap
8888 }
8989
9090 // Constraint weights for penalty calculation
@@ -61,8 +61,8 @@ private[df] class SECrossJoinAssignment extends AssignmentStrategy with Logging
6161 // Find minimum distance cluster for each point
6262 // Use window function for efficiency
6363 import org .apache .spark .sql .expressions .Window
64- import scala . collection . immutable . ArraySeq
65- val windowSpec = Window .partitionBy(ArraySeq .unsafeWrapArray(df.columns.map(col)) : _* )
64+ val partitionCols = df.columns.map(col)
65+ val windowSpec = Window .partitionBy(partitionCols : _* )
6666
6767 val withRank = withDistances
6868 .withColumn(" rank" , row_number().over(windowSpec.orderBy(" distance" )))
You can’t perform that action at this time.
0 commit comments