|
| 1 | +package com.massivedatascience.clusterer.ml |
| 2 | + |
| 3 | +import org.apache.spark.internal.Logging |
| 4 | +import org.apache.spark.ml.linalg.Vector |
| 5 | +import org.apache.spark.ml.param._ |
| 6 | +import org.apache.spark.ml.util._ |
| 7 | +import org.apache.spark.sql.Dataset |
| 8 | + |
/** Parameters shared by the CLARA (Clustering Large Applications) estimator. */
trait CLARAParams extends KMedoidsParams {

  /** Number of independent random samples drawn from the dataset. Each sample is
    * clustered on its own and the result with the lowest full-dataset cost is kept.
    */
  final val numSamples = new IntParam(
    this,
    "numSamples",
    "Number of samples to draw",
    ParamValidators.gt(0)
  )

  def getNumSamples: Int = $(numSamples)

  /** Number of points drawn into each sample. A value of -1 requests automatic
    * sizing, which uses 40 + 2*k as recommended in the original CLARA paper.
    */
  final val sampleSize = new IntParam(
    this,
    "sampleSize",
    "Sample size for each sample (-1 for auto)",
    (v: Int) => v == -1 || v > 0
  )

  def getSampleSize: Int = $(sampleSize)

  // Defaults follow the original publication: 5 samples, auto-sized (-1 => 40 + 2*k).
  setDefault(numSamples -> 5, sampleSize -> -1)
}
| 41 | + |
/** CLARA (Clustering Large Applications) - sampling-based K-Medoids for large datasets.
  *
  * CLARA is a more scalable version of PAM that works on large datasets by:
  *   1. drawing multiple random samples from the full dataset,
  *   2. running PAM (BUILD + SWAP) on each sample,
  *   3. computing each sample result's cost on the FULL dataset,
  *   4. selecting the medoids with the lowest total cost.
  *
  * Time Complexity: O(numSamples * k(s-k)^2) where s is the sample size.
  *
  * CLARA is recommended when:
  *   - the dataset has > 10,000 points,
  *   - PAM is too slow due to its O(k(n-k)^2) complexity,
  *   - a good approximation to PAM is acceptable.
  *
  * Example usage:
  * {{{
  * val clara = new CLARA()
  *   .setK(3)
  *   .setNumSamples(10)
  *   .setSampleSize(100)
  *   .setMaxIter(20)
  *   .setDistanceFunction("manhattan")
  *
  * val model = clara.fit(largeDataset)
  * val predictions = model.transform(largeDataset)
  * }}}
  *
  * @param uid
  *   unique identifier
  */
class CLARA(override val uid: String)
    extends KMedoids(uid)
    with CLARAParams
    with DefaultParamsWritable
    with Logging {

  def this() = this(Identifiable.randomUID("clara"))

  /** Fits CLARA on the input dataset.
    *
    * Collects the feature vectors to the driver, clusters `numSamples` random
    * samples with PAM, scores each candidate medoid set against the full
    * dataset, and returns a model built from the cheapest set.
    *
    * @param dataset input data; must contain `featuresCol` with ML vectors
    * @return fitted [[KMedoidsModel]] with this estimator as parent
    */
  override def fit(dataset: Dataset[_]): KMedoidsModel = {
    transformSchema(dataset.schema, logging = true)

    val df = dataset.toDF()
    // BUG FIX: the original called `row.getAs` with no type/index argument, so the
    // feature column was never actually extracted. Select column 0 as an ML Vector.
    val data: Array[Vector] = df
      .select($(featuresCol))
      .rdd
      .map(_.getAs[Vector](0))
      .collect()

    val n = data.length
    val numClusters = $(k)

    // Fail fast with clear messages instead of crashing inside BUILD/SWAP.
    require(n > 0, "CLARA cannot be fit on an empty dataset")
    require(
      numClusters <= n,
      s"Number of clusters ($numClusters) must not exceed number of points ($n)"
    )

    // Determine sample size: -1 means the paper-recommended 40 + 2k, capped at n.
    val actualSampleSize = if ($(sampleSize) == -1) {
      math.min(40 + 2 * numClusters, n)
    } else {
      math.min($(sampleSize), n)
    }
    require(
      numClusters <= actualSampleSize,
      s"Sample size ($actualSampleSize) must be at least k ($numClusters)"
    )

    logInfo(s"CLARA with k=$numClusters, numSamples=${$(numSamples)}, sampleSize=$actualSampleSize")
    logInfo(s"Dataset size: $n points")

    if (actualSampleSize >= n * 0.9) {
      logWarning(
        s"Sample size ($actualSampleSize) is >= 90% of dataset ($n). Consider using PAM instead of CLARA."
      )
    }

    // Create distance function
    val distFn = createDistanceFunction($(distanceFunction))

    // Best medoid set found so far; empty until the first sample is scored.
    var bestMedoidIndices: Array[Int] = Array.empty
    var bestCost = Double.PositiveInfinity

    // Try multiple samples; the RNG is seeded so runs are reproducible.
    val rng = new scala.util.Random($(seed))

    (0 until $(numSamples)).foreach { sampleIdx =>
      logInfo(s"Processing sample ${sampleIdx + 1}/${$(numSamples)}")

      // Draw a random sample (without replacement) of original-dataset indices.
      val sampleIndices = rng.shuffle((0 until n).toList).take(actualSampleSize).toArray
      val sample = sampleIndices.map(data)

      // Run PAM (BUILD then SWAP) on the sample only.
      val sampleMedoidIndices = buildPhase(sample, numClusters, distFn, $(seed) + sampleIdx)
      val finalSampleMedoidIndices = swapPhase(sample, sampleMedoidIndices, $(maxIter), distFn)

      // Map sample medoid indices back to original dataset indices.
      val originalMedoidIndices = finalSampleMedoidIndices.map(sampleIndices)

      // Compute cost on the FULL dataset: sum of each point's distance to its
      // nearest candidate medoid. Iterator avoids materializing intermediates.
      val totalCost = data.iterator.map { point =>
        originalMedoidIndices.map(medIdx => distFn(point, data(medIdx))).min
      }.sum

      logInfo(f"Sample ${sampleIdx + 1} cost on full dataset: $totalCost%.4f")

      // Keep the cheapest result; the isEmpty guard ensures we always adopt the
      // first sample even if its cost is non-finite.
      if (bestMedoidIndices.isEmpty || totalCost < bestCost) {
        bestCost = totalCost
        bestMedoidIndices = originalMedoidIndices
        logInfo(f"New best cost: $bestCost%.4f")
      }
    }

    logInfo(f"CLARA completed. Best cost: $bestCost%.4f")
    logInfo(s"Best medoid indices: ${bestMedoidIndices.mkString(", ")}")

    // Create model
    val medoidVectors = bestMedoidIndices.map(data)
    new KMedoidsModel(uid, medoidVectors, bestMedoidIndices, $(distanceFunction)).setParent(this)
  }

  // Parameter setters
  def setNumSamples(value: Int): this.type = set(numSamples, value)
  def setSampleSize(value: Int): this.type = set(sampleSize, value)

  override def copy(extra: ParamMap): CLARA = defaultCopy(extra)
}
| 165 | + |
/** Companion object enabling `CLARA.load(path)` for models/estimators saved via
  * `DefaultParamsWritable`.
  */
object CLARA extends DefaultParamsReadable[CLARA] {
  // Override exists only to narrow the return type to CLARA (helps Java callers
  // and generated docs); behavior is identical to the inherited load.
  override def load(path: String): CLARA = super.load(path)
}
0 commit comments