|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "attachments": {}, |
| 5 | + "cell_type": "markdown", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "We will learn how to perform clustering in this section. First, we will write the evaluation function, which provides various clustering metrics. For each metric, the closer the value is to 1, the better.\n" |
| 9 | + ] |
| 10 | + }, |
| 11 | + { |
| 12 | + "cell_type": "code", |
| 13 | + "execution_count": 2, |
| 14 | + "metadata": {}, |
| 15 | + "outputs": [], |
| 16 | + "source": [ |
| 17 | + "import numpy as np\n", |
| 18 | + "from sklearn.metrics import (\n", |
| 19 | + " normalized_mutual_info_score,\n", |
| 20 | + " adjusted_rand_score,\n", |
| 21 | + " precision_score,\n", |
| 22 | + " recall_score,\n", |
| 23 | + " f1_score\n", |
| 24 | + ")\n", |
| 25 | + "from scipy.optimize import linear_sum_assignment\n", |
| 26 | + "from collections import Counter\n", |
| 27 | + "\n", |
| 28 | + "def compute_label_alignment(y_true, y_pred):\n", |
| 29 | + " y_true = np.asarray(y_true)\n", |
| 30 | + " y_pred = np.asarray(y_pred)\n", |
| 31 | + " D = max(y_pred.max(), y_true.max()) + 1\n", |
| 32 | + " w = np.zeros((D, D), dtype=np.int64)\n", |
| 33 | + " for i in range(y_pred.size):\n", |
| 34 | + " w[y_pred[i], y_true[i]] += 1\n", |
| 35 | + " row_ind, col_ind = linear_sum_assignment(w.max() - w)\n", |
| 36 | + " mapping = {row: col for row, col in zip(row_ind, col_ind)}\n", |
| 37 | + " y_pred_aligned = np.array([mapping[label] for label in y_pred])\n", |
| 38 | + " acc = sum(w[i, j] for i, j in zip(row_ind, col_ind)) / y_pred.size\n", |
| 39 | + " return acc, y_pred_aligned\n", |
| 40 | + "\n", |
| 41 | + "def purity_score(y_true, y_pred):\n", |
| 42 | + " y_true = np.asarray(y_true)\n", |
| 43 | + " y_pred = np.asarray(y_pred)\n", |
| 44 | + " total = 0\n", |
| 45 | + " for cluster in np.unique(y_pred):\n", |
| 46 | + " indices = np.where(y_pred == cluster)[0]\n", |
| 47 | + " true_labels = y_true[indices]\n", |
| 48 | + " most_common = Counter(true_labels).most_common(1)\n", |
| 49 | + " if most_common:\n", |
| 50 | + " total += most_common[0][1]\n", |
| 51 | + " return total / len(y_true)\n", |
| 52 | + "\n", |
| 53 | + "def evaluate(y_true, y_pred, method='macro'):\n", |
| 54 | + " y_true = np.asarray(y_true)\n", |
| 55 | + " y_pred = np.asarray(y_pred)\n", |
| 56 | + "\n", |
| 57 | + " # ACC & aligned labels\n", |
| 58 | + " acc, y_pred_aligned = compute_label_alignment(y_true, y_pred)\n", |
| 59 | + "\n", |
| 60 | + " # Metrics\n", |
| 61 | + " nmi = normalized_mutual_info_score(y_true, y_pred)\n", |
| 62 | + " purity = purity_score(y_true, y_pred)\n", |
| 63 | + " precision = precision_score(y_true, y_pred_aligned, average=method, zero_division=0)\n", |
| 64 | + " recall = recall_score(y_true, y_pred_aligned, average=method, zero_division=0)\n", |
| 65 | + " f1 = f1_score(y_true, y_pred_aligned, average=method, zero_division=0)\n", |
| 66 | + " ari = adjusted_rand_score(y_true, y_pred)\n", |
| 67 | + "\n", |
| 68 | + " return np.array([acc, nmi, purity, f1, precision, recall, ari])\n" |
| 69 | + ] |
| 70 | + }, |
| 71 | + { |
| 72 | + "cell_type": "markdown", |
| 73 | + "metadata": {}, |
| 74 | + "source": [ |
| 75 | + "Import the necessary packages.\n" |
| 76 | + ] |
| 77 | + }, |
| 78 | + { |
| 79 | + "cell_type": "code", |
| 80 | + "execution_count": 3, |
| 81 | + "metadata": {}, |
| 82 | + "outputs": [], |
| 83 | + "source": [ |
| 84 | + "import torch\n", |
| 85 | + "from manify.manifolds import ProductManifold\n", |
| 86 | + "from manify.clustering.fuzzy_kmeans import RiemannianFuzzyKMeans\n", |
| 87 | + "import numpy as np" |
| 88 | + ] |
| 89 | + }, |
| 90 | + { |
| 91 | + "cell_type": "markdown", |
| 92 | + "metadata": {}, |
| 93 | + "source": [ |
| 94 | + "First, generate a **Product Manifold** using the following method." |
| 95 | + ] |
| 96 | + }, |
| 97 | + { |
| 98 | + "cell_type": "code", |
| 99 | + "execution_count": 4, |
| 100 | + "metadata": {}, |
| 101 | + "outputs": [], |
| 102 | + "source": [ |
| 103 | + "# 1. Define the signature: a 3-factor manifold\n", |
| 104 | + "import numpy as np\n", |
| 105 | + "# (curvature, dimension)\n", |
| 106 | + "signature = [\n", |
| 107 | + " (0.0, 4), # R^2 (Euclidean space)\n", |
| 108 | + " (1.0, 4), # S^2 (Spherical space)\n", |
| 109 | + " (-1.0, 4), # H^2 (Hyperbolic space)\n", |
| 110 | + "]\n", |
| 111 | + "\n", |
| 112 | + "# 2. Construct the ProductManifold (without stereographic projection)\n", |
| 113 | + "P = ProductManifold(signature, device=\"cpu\", stereographic=False)" |
| 114 | + ] |
| 115 | + }, |
| 116 | + { |
| 117 | + "cell_type": "code", |
| 118 | + "execution_count": 5, |
| 119 | + "metadata": {}, |
| 120 | + "outputs": [], |
| 121 | + "source": [ |
| 122 | + "#setting param\n", |
| 123 | + "n_clusters = 3\n", |
| 124 | + "seed = 0\n", |
| 125 | + "opt = 'adan'\n", |
| 126 | + "lr = .01\n", |
| 127 | + "tol = 1e-6" |
| 128 | + ] |
| 129 | + }, |
| 130 | + { |
| 131 | + "cell_type": "code", |
| 132 | + "execution_count": 6, |
| 133 | + "metadata": {}, |
| 134 | + "outputs": [], |
| 135 | + "source": [ |
| 136 | + "# 3. Generate data using gaussian_mixture\n", |
| 137 | + "# - num_points=500: sample 500 points\n", |
| 138 | + "# - num_classes=n_clusters: generate n_clusters class labels (for clustering)\n", |
| 139 | + "# - seed=seed: fix the random seed for reproducibility\n", |
| 140 | + "X, y_true = P.gaussian_mixture(\n", |
| 141 | + " num_points=500,\n", |
| 142 | + " num_classes=n_clusters,\n", |
| 143 | + " seed=seed,\n", |
| 144 | + " task=\"classification\",\n", |
| 145 | + " cov_scale_points=.1 # <--- try decreasing this value\n", |
| 146 | + ")\n", |
| 147 | + "y_true = np.array(y_true)" |
| 148 | + ] |
| 149 | + }, |
| 150 | + { |
| 151 | + "cell_type": "markdown", |
| 152 | + "metadata": {}, |
| 153 | + "source": [ |
| 154 | + "Call the `RiemannianFuzzyKMeans` algorithm from the `fuzzy_kmeans` module in the `manify` clustering package to perform clustering on a manifold." |
| 155 | + ] |
| 156 | + }, |
| 157 | + { |
| 158 | + "cell_type": "code", |
| 159 | + "execution_count": 7, |
| 160 | + "metadata": {}, |
| 161 | + "outputs": [ |
| 162 | + { |
| 163 | + "name": "stdout", |
| 164 | + "output_type": "stream", |
| 165 | + "text": [ |
| 166 | + "RFK iter 1, loss=1911.6559\n", |
| 167 | + "RFK iter 2, loss=1909.2720\n", |
| 168 | + "RFK iter 3, loss=1907.2013\n", |
| 169 | + "RFK iter 4, loss=1905.3054\n", |
| 170 | + "RFK iter 5, loss=1903.5615\n", |
| 171 | + "RFK iter 6, loss=1901.9583\n", |
| 172 | + "RFK iter 7, loss=1900.4875\n", |
| 173 | + "RFK iter 8, loss=1899.1450\n", |
| 174 | + "RFK iter 9, loss=1897.9203\n", |
| 175 | + "RFK iter 10, loss=1896.8094\n", |
| 176 | + "RFK iter 11, loss=1895.7996\n", |
| 177 | + "RFK iter 12, loss=1894.8835\n", |
| 178 | + "RFK iter 13, loss=1894.0524\n", |
| 179 | + "RFK iter 14, loss=1893.2966\n", |
| 180 | + "RFK iter 15, loss=1892.6089\n", |
| 181 | + "RFK iter 16, loss=1891.9822\n", |
| 182 | + "RFK iter 17, loss=1891.4094\n", |
| 183 | + "RFK iter 18, loss=1890.8866\n", |
| 184 | + "RFK iter 19, loss=1890.4080\n", |
| 185 | + "RFK iter 20, loss=1889.9674\n", |
| 186 | + "RFK iter 21, loss=1889.5638\n", |
| 187 | + "RFK iter 22, loss=1889.1930\n", |
| 188 | + "RFK iter 23, loss=1888.8501\n", |
| 189 | + "RFK iter 24, loss=1888.5347\n", |
| 190 | + "RFK iter 25, loss=1888.2428\n", |
| 191 | + "RFK iter 26, loss=1887.9733\n", |
| 192 | + "RFK iter 27, loss=1887.7229\n", |
| 193 | + "RFK iter 28, loss=1887.4913\n", |
| 194 | + "RFK iter 29, loss=1887.2771\n", |
| 195 | + "RFK iter 30, loss=1887.0775\n", |
| 196 | + "RFK iter 31, loss=1886.8916\n", |
| 197 | + "RFK iter 32, loss=1886.7195\n", |
| 198 | + "RFK iter 33, loss=1886.5583\n", |
| 199 | + "RFK iter 34, loss=1886.4092\n", |
| 200 | + "RFK iter 35, loss=1886.2698\n", |
| 201 | + "RFK iter 36, loss=1886.1398\n", |
| 202 | + "RFK iter 37, loss=1886.0182\n", |
| 203 | + "RFK iter 38, loss=1885.9049\n", |
| 204 | + "RFK iter 39, loss=1885.7992\n", |
| 205 | + "RFK iter 40, loss=1885.7002\n", |
| 206 | + "RFK iter 41, loss=1885.6074\n", |
| 207 | + "RFK iter 42, loss=1885.5209\n", |
| 208 | + "RFK iter 43, loss=1885.4402\n", |
| 209 | + "RFK iter 44, loss=1885.3644\n", |
| 210 | + "RFK iter 45, loss=1885.2936\n", |
| 211 | + "RFK iter 46, loss=1885.2272\n", |
| 212 | + "RFK iter 47, loss=1885.1655\n", |
| 213 | + "RFK iter 48, loss=1885.1078\n", |
| 214 | + "RFK iter 49, loss=1885.0532\n", |
| 215 | + "RFK iter 50, loss=1885.0027\n", |
| 216 | + "RFK iter 51, loss=1884.9554\n", |
| 217 | + "RFK iter 52, loss=1884.9111\n", |
| 218 | + "RFK iter 53, loss=1884.8702\n", |
| 219 | + "RFK iter 54, loss=1884.8313\n", |
| 220 | + "RFK iter 55, loss=1884.7960\n", |
| 221 | + "RFK iter 56, loss=1884.7618\n", |
| 222 | + "RFK iter 57, loss=1884.7318\n", |
| 223 | + "RFK iter 58, loss=1884.7024\n", |
| 224 | + "RFK iter 59, loss=1884.6760\n", |
| 225 | + "RFK iter 60, loss=1884.6500\n", |
| 226 | + "RFK iter 61, loss=1884.6274\n", |
| 227 | + "RFK iter 62, loss=1884.6057\n", |
| 228 | + "RFK iter 63, loss=1884.5859\n", |
| 229 | + "RFK iter 64, loss=1884.5674\n", |
| 230 | + "RFK iter 65, loss=1884.5500\n", |
| 231 | + "RFK iter 66, loss=1884.5342\n", |
| 232 | + "RFK iter 67, loss=1884.5195\n", |
| 233 | + "RFK iter 68, loss=1884.5063\n", |
| 234 | + "RFK iter 69, loss=1884.4939\n", |
| 235 | + "RFK iter 70, loss=1884.4825\n", |
| 236 | + "RFK iter 71, loss=1884.4719\n", |
| 237 | + "RFK iter 72, loss=1884.4620\n", |
| 238 | + "RFK iter 73, loss=1884.4526\n", |
| 239 | + "RFK iter 74, loss=1884.4442\n", |
| 240 | + "RFK iter 75, loss=1884.4366\n", |
| 241 | + "RFK iter 76, loss=1884.4297\n", |
| 242 | + "RFK iter 77, loss=1884.4235\n", |
| 243 | + "RFK iter 78, loss=1884.4174\n", |
| 244 | + "RFK iter 79, loss=1884.4120\n", |
| 245 | + "RFK iter 80, loss=1884.4073\n", |
| 246 | + "RFK iter 81, loss=1884.4026\n", |
| 247 | + "RFK iter 82, loss=1884.3982\n", |
| 248 | + "RFK iter 83, loss=1884.3947\n", |
| 249 | + "RFK iter 84, loss=1884.3909\n", |
| 250 | + "RFK iter 85, loss=1884.3877\n", |
| 251 | + "RFK iter 86, loss=1884.3842\n", |
| 252 | + "RFK iter 87, loss=1884.3817\n", |
| 253 | + "RFK iter 88, loss=1884.3794\n", |
| 254 | + "RFK iter 89, loss=1884.3767\n", |
| 255 | + "RFK iter 90, loss=1884.3749\n", |
| 256 | + "RFK iter 91, loss=1884.3727\n", |
| 257 | + "RFK iter 92, loss=1884.3707\n", |
| 258 | + "RFK iter 93, loss=1884.3694\n", |
| 259 | + "RFK iter 94, loss=1884.3676\n", |
| 260 | + "RFK iter 95, loss=1884.3666\n", |
| 261 | + "RFK iter 96, loss=1884.3644\n", |
| 262 | + "RFK iter 97, loss=1884.3638\n", |
| 263 | + "RFK iter 98, loss=1884.3623\n", |
| 264 | + "RFK iter 99, loss=1884.3612\n", |
| 265 | + "RFK iter 100, loss=1884.3602\n", |
| 266 | + "RFK iter 101, loss=1884.3591\n", |
| 267 | + "RFK iter 102, loss=1884.3583\n", |
| 268 | + "RFK iter 103, loss=1884.3574\n", |
| 269 | + "RFK iter 104, loss=1884.3569\n", |
| 270 | + "RFK iter 105, loss=1884.3561\n", |
| 271 | + "RFK iter 106, loss=1884.3547\n", |
| 272 | + "RFK iter 107, loss=1884.3544\n", |
| 273 | + "RFK iter 108, loss=1884.3538\n", |
| 274 | + "RFK iter 109, loss=1884.3536\n", |
| 275 | + "RFK iter 110, loss=1884.3523\n", |
| 276 | + "RFK iter 111, loss=1884.3518\n", |
| 277 | + "RFK iter 112, loss=1884.3511\n", |
| 278 | + "RFK iter 113, loss=1884.3511\n" |
| 279 | + ] |
| 280 | + } |
| 281 | + ], |
| 282 | + "source": [ |
| 283 | + "model = RiemannianFuzzyKMeans(n_clusters, \n", |
| 284 | + " manifold=P,\n", |
| 285 | + " random_state=seed, \n", |
| 286 | + " max_iter=1000,\n", |
| 287 | + " tol=tol,\n", |
| 288 | + " optimizer=opt,\n", |
| 289 | + " lr=lr,\n", |
| 290 | + " verbose=True)\n", |
| 291 | + "labels = model.fit_predict(X)" |
| 292 | + ] |
| 293 | + }, |
| 294 | + { |
| 295 | + "cell_type": "markdown", |
| 296 | + "metadata": {}, |
| 297 | + "source": [ |
| 298 | + "What if we don't use a manifold-based clustering method and instead apply standard KMeans? We'll compare the results to evaluate the performance difference." |
| 299 | + ] |
| 300 | + }, |
| 301 | + { |
| 302 | + "cell_type": "code", |
| 303 | + "execution_count": 8, |
| 304 | + "metadata": {}, |
| 305 | + "outputs": [], |
| 306 | + "source": [ |
| 307 | + "from sklearn.cluster import KMeans\n", |
| 308 | + "kmeans = KMeans(n_clusters=n_clusters, random_state=0)\n", |
| 309 | + "# Fit the data\n", |
| 310 | + "kmeans.fit(X)\n", |
| 311 | + "# Get the cluster labels from kmeans\n", |
| 312 | + "labels_km = kmeans.labels_" |
| 313 | + ] |
| 314 | + }, |
| 315 | + { |
| 316 | + "cell_type": "code", |
| 317 | + "execution_count": 9, |
| 318 | + "metadata": {}, |
| 319 | + "outputs": [ |
| 320 | + { |
| 321 | + "name": "stdout", |
| 322 | + "output_type": "stream", |
| 323 | + "text": [ |
| 324 | + "[[0.998 0.98869808 0.998 0.99811609 0.99801587 0.99822695\n", |
| 325 | + " 0.99363694]]\n", |
| 326 | + "[[0.44 0.07972898 0.444 0.30688818 0.33132184 0.40742235\n", |
| 327 | + " 0.03685404]]\n" |
| 328 | + ] |
| 329 | + } |
| 330 | + ], |
| 331 | + "source": [ |
| 332 | + "result = evaluate(y_true, labels).reshape(1, -1)\n", |
| 333 | + "result2 = evaluate(y_true, labels_km).reshape(1, -1)\n", |
| 334 | + "print(result)\n", |
| 335 | + "print(result2)" |
| 336 | + ] |
| 337 | + }, |
| 338 | + { |
| 339 | + "cell_type": "markdown", |
| 340 | + "metadata": {}, |
| 341 | + "source": [ |
| 342 | + "The performance of **Riemannian Fuzzy KMeans** seems to be much better than that of standard **KMeans**. Let's try adjusting some parameters to see if we can improve or better understand the results!\n" |
| 343 | + ] |
| 344 | + } |
| 345 | + ], |
| 346 | + "metadata": { |
| 347 | + "kernelspec": { |
| 348 | + "display_name": "RoE", |
| 349 | + "language": "python", |
| 350 | + "name": "python3" |
| 351 | + }, |
| 352 | + "language_info": { |
| 353 | + "codemirror_mode": { |
| 354 | + "name": "ipython", |
| 355 | + "version": 3 |
| 356 | + }, |
| 357 | + "file_extension": ".py", |
| 358 | + "mimetype": "text/x-python", |
| 359 | + "name": "python", |
| 360 | + "nbconvert_exporter": "python", |
| 361 | + "pygments_lexer": "ipython3", |
| 362 | + "version": "3.10.16" |
| 363 | + } |
| 364 | + }, |
| 365 | + "nbformat": 4, |
| 366 | + "nbformat_minor": 1 |
| 367 | +} |
0 commit comments