Skip to content

Commit f530bd9

Browse files
committed
Update sparse_finch notebook
1 parent 128a567 commit f530bd9

File tree

1 file changed

+209
-23
lines changed

1 file changed

+209
-23
lines changed

examples/sparse_finch.ipynb

Lines changed: 209 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"import sparse\n",
4040
"\n",
4141
"import matplotlib.pyplot as plt\n",
42+
"import networkx as nx\n",
4243
"\n",
4344
"import numpy as np\n",
4445
"import scipy.sparse as sps\n",
@@ -105,7 +106,7 @@
105106
"metadata": {},
106107
"outputs": [],
107108
"source": [
108-
"ITERS = 3\n",
109+
"ITERS = 1\n",
109110
"rng = np.random.default_rng(0)"
110111
]
111112
},
@@ -134,6 +135,13 @@
134135
" return elapsed / ITERS"
135136
]
136137
},
138+
{
139+
"cell_type": "markdown",
140+
"metadata": {},
141+
"source": [
142+
"## MTTKRP"
143+
]
144+
},
137145
{
138146
"cell_type": "code",
139147
"execution_count": null,
@@ -146,26 +154,30 @@
146154
"importlib.reload(sparse)\n",
147155
"\n",
148156
"configs = [\n",
149-
" {\"I_\": 100, \"J_\": 25, \"K_\": 10, \"L_\": 10, \"DENSITY\": 0.001},\n",
150157
" {\"I_\": 100, \"J_\": 25, \"K_\": 100, \"L_\": 10, \"DENSITY\": 0.001},\n",
151158
" {\"I_\": 100, \"J_\": 25, \"K_\": 100, \"L_\": 100, \"DENSITY\": 0.001},\n",
152159
" {\"I_\": 1000, \"J_\": 25, \"K_\": 100, \"L_\": 100, \"DENSITY\": 0.001},\n",
153160
" {\"I_\": 1000, \"J_\": 25, \"K_\": 1000, \"L_\": 100, \"DENSITY\": 0.001},\n",
154161
" {\"I_\": 1000, \"J_\": 25, \"K_\": 1000, \"L_\": 1000, \"DENSITY\": 0.001},\n",
155162
"]\n",
156-
"nonzeros = [10000, 100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000]\n",
163+
"nonzeros = [100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000]\n",
157164
"\n",
158165
"if CI_MODE:\n",
159166
" configs = configs[:1]\n",
160167
" nonzeros = nonzeros[:1]\n",
161168
"\n",
162169
"finch_times = []\n",
163170
"numba_times = []\n",
171+
"finch_galley_times = []\n",
164172
"\n",
165173
"for config in configs:\n",
166-
" B_sps = sparse.random((config[\"I_\"], config[\"K_\"], config[\"L_\"]), density=config[\"DENSITY\"], random_state=rng) * 10\n",
167-
" D_sps = rng.random((config[\"L_\"], config[\"J_\"])) * 10\n",
168-
" C_sps = rng.random((config[\"K_\"], config[\"J_\"])) * 10\n",
174+
" B_sps = sparse.random(\n",
175+
" (config[\"I_\"], config[\"K_\"], config[\"L_\"]),\n",
176+
" density=config[\"DENSITY\"],\n",
177+
" random_state=rng,\n",
178+
" )\n",
179+
" D_sps = rng.random((config[\"L_\"], config[\"J_\"]))\n",
180+
" C_sps = rng.random((config[\"K_\"], config[\"J_\"]))\n",
169181
"\n",
170182
" # ======= Finch =======\n",
171183
" os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
@@ -175,7 +187,7 @@
175187
" D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
176188
" C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
177189
"\n",
178-
" @sparse.compiled\n",
190+
" @sparse.compiled(opt=\"default\")\n",
179191
" def mttkrp_finch(B, D, C):\n",
180192
" return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
181193
"\n",
@@ -184,6 +196,23 @@
184196
" # Benchmark\n",
185197
" time_finch = benchmark(mttkrp_finch, info=\"Finch\", args=[B, D, C])\n",
186198
"\n",
199+
" # ======= Finch Galley =======\n",
200+
" os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
201+
" importlib.reload(sparse)\n",
202+
"\n",
203+
" B = sparse.asarray(B_sps.todense(), format=\"csf\")\n",
204+
" D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
205+
" C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
206+
"\n",
207+
" @sparse.compiled(opt=\"galley\")\n",
208+
" def mttkrp_finch(B, D, C):\n",
209+
" return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
210+
"\n",
211+
" # Compile\n",
212+
" result_finch_galley = mttkrp_finch(B, D, C)\n",
213+
" # Benchmark\n",
214+
"    time_finch_galley = benchmark(mttkrp_finch, info=\"Finch Galley\", args=[B, D, C])\n",
215+
"\n",
187216
" # ======= Numba =======\n",
188217
" os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
189218
" importlib.reload(sparse)\n",
@@ -201,8 +230,10 @@
201230
" time_numba = benchmark(mttkrp_numba, info=\"Numba\", args=[B, D, C])\n",
202231
"\n",
203232
" np.testing.assert_allclose(result_finch.todense(), result_numba.todense())\n",
233+
"    np.testing.assert_allclose(result_finch_galley.todense(), result_numba.todense())\n",
204234
" finch_times.append(time_finch)\n",
205-
" numba_times.append(time_numba)"
235+
" numba_times.append(time_numba)\n",
236+
" finch_galley_times.append(time_finch_galley)"
206237
]
207238
},
208239
{
@@ -215,6 +246,7 @@
215246
"\n",
216247
"ax.plot(nonzeros, finch_times, \"o-\", label=\"Finch\")\n",
217248
"ax.plot(nonzeros, numba_times, \"o-\", label=\"Numba\")\n",
249+
"ax.plot(nonzeros, finch_galley_times, \"o-\", label=\"Finch Galley\")\n",
218250
"ax.grid(True)\n",
219251
"ax.set_xlabel(\"no. of elements\")\n",
220252
"ax.set_ylabel(\"time (sec)\")\n",
@@ -226,6 +258,13 @@
226258
"plt.show()"
227259
]
228260
},
261+
{
262+
"cell_type": "markdown",
263+
"metadata": {},
264+
"source": [
265+
"## SDDMM"
266+
]
267+
},
229268
{
230269
"cell_type": "code",
231270
"execution_count": null,
@@ -235,15 +274,13 @@
235274
"print(\"SDDMM Example:\\n\")\n",
236275
"\n",
237276
"configs = [\n",
238-
" {\"LEN\": 10, \"DENSITY\": 0.1},\n",
239-
" {\"LEN\": 50, \"DENSITY\": 0.05},\n",
240-
" {\"LEN\": 100, \"DENSITY\": 0.01},\n",
241-
" {\"LEN\": 500, \"DENSITY\": 0.005},\n",
242-
" {\"LEN\": 1000, \"DENSITY\": 0.001},\n",
243-
" {\"LEN\": 5000, \"DENSITY\": 0.00005},\n",
277+
" {\"LEN\": 5000, \"DENSITY\": 0.00001},\n",
244278
" {\"LEN\": 10000, \"DENSITY\": 0.00001},\n",
279+
" {\"LEN\": 20000, \"DENSITY\": 0.00001},\n",
280+
" {\"LEN\": 25000, \"DENSITY\": 0.00001},\n",
281+
" {\"LEN\": 30000, \"DENSITY\": 0.00001},\n",
245282
"]\n",
246-
"size_n = [10, 50, 100, 500, 1000, 5000, 10000]\n",
283+
"size_n = [5000, 10000, 20000, 25000, 30000]\n",
247284
"\n",
248285
"if CI_MODE:\n",
249286
" configs = configs[:1]\n",
@@ -252,25 +289,27 @@
252289
"finch_times = []\n",
253290
"numba_times = []\n",
254291
"scipy_times = []\n",
292+
"finch_galley_times = []\n",
255293
"\n",
256294
"for config in configs:\n",
257295
" LEN = config[\"LEN\"]\n",
258296
" DENSITY = config[\"DENSITY\"]\n",
259297
"\n",
260-
" a_sps = rng.random((LEN, LEN)) * 10\n",
261-
" b_sps = rng.random((LEN, LEN)) * 10\n",
262-
" s_sps = sps.random(LEN, LEN, format=\"coo\", density=DENSITY, random_state=rng) * 10\n",
298+
" a_sps = rng.random((LEN, LEN))\n",
299+
" b_sps = rng.random((LEN, LEN))\n",
300+
" s_sps = sps.random(LEN, LEN, format=\"coo\", density=DENSITY, random_state=rng)\n",
263301
" s_sps.sum_duplicates()\n",
264302
"\n",
265303
" # ======= Finch =======\n",
304+
" print(\"finch\")\n",
266305
" os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
267306
" importlib.reload(sparse)\n",
268307
"\n",
269308
" s = sparse.asarray(s_sps)\n",
270309
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
271310
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
272311
"\n",
273-
" @sparse.compiled\n",
312+
" @sparse.compiled(opt=\"default\")\n",
274313
" def sddmm_finch(s, a, b):\n",
275314
" return sparse.sum(\n",
276315
" s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
@@ -282,7 +321,30 @@
282321
" # Benchmark\n",
283322
" time_finch = benchmark(sddmm_finch, info=\"Finch\", args=[s, a, b])\n",
284323
"\n",
324+
" # ======= Finch Galley =======\n",
325+
" print(\"finch galley\")\n",
326+
" os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
327+
" importlib.reload(sparse)\n",
328+
"\n",
329+
" s = sparse.asarray(s_sps)\n",
330+
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
331+
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
332+
"\n",
333+
" @sparse.compiled(opt=\"galley\")\n",
334+
" def sddmm_finch(s, a, b):\n",
335+
" # return s * (a @ b)\n",
336+
" return sparse.sum(\n",
337+
" s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
338+
" axis=-1,\n",
339+
" )\n",
340+
"\n",
341+
" # Compile\n",
342+
" result_finch_galley = sddmm_finch(s, a, b)\n",
343+
" # Benchmark\n",
344+
"    time_finch_galley = benchmark(sddmm_finch, info=\"Finch Galley\", args=[s, a, b])\n",
345+
"\n",
285346
" # ======= Numba =======\n",
347+
" print(\"numba\")\n",
286348
" os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
287349
" importlib.reload(sparse)\n",
288350
"\n",
@@ -299,6 +361,8 @@
299361
" time_numba = benchmark(sddmm_numba, info=\"Numba\", args=[s, a, b])\n",
300362
"\n",
301363
" # ======= SciPy =======\n",
364+
" print(\"scipy\")\n",
365+
"\n",
302366
" def sddmm_scipy(s, a, b):\n",
303367
" return s.multiply(a @ b)\n",
304368
"\n",
@@ -312,7 +376,8 @@
312376
"\n",
313377
" finch_times.append(time_finch)\n",
314378
" numba_times.append(time_numba)\n",
315-
" scipy_times.append(time_scipy)"
379+
" scipy_times.append(time_scipy)\n",
380+
" finch_galley_times.append(time_finch_galley)"
316381
]
317382
},
318383
{
@@ -326,13 +391,134 @@
326391
"ax.plot(size_n, finch_times, \"o-\", label=\"Finch\")\n",
327392
"ax.plot(size_n, numba_times, \"o-\", label=\"Numba\")\n",
328393
"ax.plot(size_n, scipy_times, \"o-\", label=\"SciPy\")\n",
394+
"ax.plot(size_n, finch_galley_times, \"o-\", label=\"Finch Galley\")\n",
329395
"\n",
330396
"ax.grid(True)\n",
331397
"ax.set_xlabel(\"size N\")\n",
332398
"ax.set_ylabel(\"time (sec)\")\n",
333399
"ax.set_title(\"SDDMM\")\n",
334-
"ax.set_xscale(\"log\")\n",
335-
"# ax.set_yscale('log')\n",
400+
"# ax.set_xscale(\"log\")\n",
401+
"# ax.set_yscale(\"log\")\n",
402+
"ax.legend(loc=\"best\", numpoints=1)\n",
403+
"\n",
404+
"plt.show()"
405+
]
406+
},
407+
{
408+
"cell_type": "code",
409+
"execution_count": null,
410+
"metadata": {},
411+
"outputs": [],
412+
"source": [
413+
"print(\"Counting Triangles Example:\\n\")\n",
414+
"\n",
415+
"configs = [\n",
416+
" {\"LEN\": 1000, \"DENSITY\": 0.1},\n",
417+
" {\"LEN\": 2000, \"DENSITY\": 0.1},\n",
418+
" {\"LEN\": 3000, \"DENSITY\": 0.1},\n",
419+
" {\"LEN\": 4000, \"DENSITY\": 0.1},\n",
420+
" {\"LEN\": 5000, \"DENSITY\": 0.1},\n",
421+
"]\n",
422+
"size_n = [1000, 2000, 3000, 4000, 5000]\n",
423+
"\n",
424+
"if CI_MODE:\n",
425+
" configs = configs[:1]\n",
426+
" size_n = size_n[:1]\n",
427+
"\n",
428+
"finch_times = []\n",
429+
"finch_galley_times = []\n",
430+
"networkx_times = []\n",
431+
"scipy_times = []\n",
432+
"\n",
433+
"for config in configs:\n",
434+
" LEN = config[\"LEN\"]\n",
435+
" DENSITY = config[\"DENSITY\"]\n",
436+
"\n",
437+
"    G = nx.gnp_random_graph(n=LEN, p=DENSITY, seed=0)\n",
438+
" a_sps = nx.to_scipy_sparse_array(G)\n",
439+
"\n",
440+
" # ======= Finch =======\n",
441+
" print(\"finch\")\n",
442+
" os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
443+
" importlib.reload(sparse)\n",
444+
"\n",
445+
" a = sparse.asarray(a_sps)\n",
446+
"\n",
447+
" @sparse.compiled(opt=\"default\")\n",
448+
" def ct_finch(a):\n",
449+
" return sparse.sum(\n",
450+
" a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
451+
" ) / sparse.asarray(6)\n",
452+
"\n",
453+
" # Compile\n",
454+
" result_finch = ct_finch(a)\n",
455+
" # Benchmark\n",
456+
" time_finch = benchmark(ct_finch, info=\"Finch\", args=[a])\n",
457+
"\n",
458+
" # ======= Finch Galley =======\n",
459+
" print(\"finch galley\")\n",
460+
" os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
461+
" importlib.reload(sparse)\n",
462+
"\n",
463+
" a = sparse.asarray(a_sps)\n",
464+
"\n",
465+
" @sparse.compiled(opt=\"galley\")\n",
466+
" def ct_finch(a):\n",
467+
" return sparse.sum(\n",
468+
" a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
469+
" ) / sparse.asarray(6)\n",
470+
"\n",
471+
" # Compile\n",
472+
" result_finch_galley = ct_finch(a)\n",
473+
" # Benchmark\n",
474+
"    time_finch_galley = benchmark(ct_finch, info=\"Finch Galley\", args=[a])\n",
475+
"\n",
476+
" # ======= SciPy =======\n",
477+
" print(\"scipy\")\n",
478+
"\n",
479+
" def ct_scipy(a):\n",
480+
" return (a @ a * a).sum() / 6\n",
481+
"\n",
482+
" a = a_sps\n",
483+
"\n",
484+
" # Benchmark\n",
485+
" time_scipy = benchmark(ct_scipy, info=\"SciPy\", args=[a])\n",
486+
"\n",
487+
" # ======= NetworkX =======\n",
488+
" print(\"networkx\")\n",
489+
"\n",
490+
" def ct_networkx(a):\n",
491+
" return sum(nx.triangles(a).values()) / 3\n",
492+
"\n",
493+
" a = G\n",
494+
"\n",
495+
"    time_networkx = benchmark(ct_networkx, info=\"NetworkX\", args=[a])\n",
496+
"\n",
497+
" finch_times.append(time_finch)\n",
498+
" finch_galley_times.append(time_finch_galley)\n",
499+
" networkx_times.append(time_networkx)\n",
500+
" scipy_times.append(time_scipy)"
501+
]
502+
},
503+
{
504+
"cell_type": "code",
505+
"execution_count": null,
506+
"metadata": {},
507+
"outputs": [],
508+
"source": [
509+
"fig, ax = plt.subplots(nrows=1, ncols=1)\n",
510+
"\n",
511+
"ax.plot(size_n, finch_times, \"o-\", label=\"Finch\")\n",
512+
"ax.plot(size_n, networkx_times, \"o-\", label=\"NetworkX\")\n",
513+
"ax.plot(size_n, scipy_times, \"o-\", label=\"SciPy\")\n",
514+
"ax.plot(size_n, finch_galley_times, \"o-\", label=\"Finch Galley\")\n",
515+
"\n",
516+
"ax.grid(True)\n",
517+
"ax.set_xlabel(\"size N\")\n",
518+
"ax.set_ylabel(\"time (sec)\")\n",
519+
"ax.set_title(\"Counting Triangles\")\n",
520+
"# ax.set_xscale(\"log\")\n",
521+
"# ax.set_yscale(\"log\")\n",
336522
"ax.legend(loc=\"best\", numpoints=1)\n",
337523
"\n",
338524
"plt.show()"
@@ -355,7 +541,7 @@
355541
"name": "python",
356542
"nbconvert_exporter": "python",
357543
"pygments_lexer": "ipython3",
358-
"version": "3.12.2"
544+
"version": "3.10.14"
359545
}
360546
},
361547
"nbformat": 4,

0 commit comments

Comments
 (0)