Skip to content

Commit bf8236e

Browse files
authored
Merge pull request #827 from pydata/update-finch-notebook
Update `sparse_finch.ipynb`
2 parents 68eb00f + 77c155d commit bf8236e

File tree

4 files changed

+45
-47
lines changed

4 files changed

+45
-47
lines changed

ci/environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ dependencies:
1313
- pytest-cov
1414
- pytest-xdist
1515
- pip:
16-
- finch-tensor>=0.2.1
16+
- finch-tensor>=0.2.2
1717
- finch-mlir>=0.0.2
1818
- pytest-codspeed

examples/sparse_finch.ipynb

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
"X = sparse.asarray(X, format=\"csc\")\n",
8383
"X_lazy = sparse.lazy(X)\n",
8484
"\n",
85-
"X_X = sparse.compute(sparse.permute_dims(X_lazy, (1, 0)) @ X_lazy, verbose=True)\n",
85+
"X_X = sparse.compute(sparse.permute_dims(X_lazy, (1, 0)) @ X_lazy)\n",
8686
"\n",
8787
"X_X = sparse.asarray(X_X, format=\"csc\") # move back from dense to CSC format\n",
8888
"\n",
@@ -171,11 +171,8 @@
171171
"finch_galley_times = []\n",
172172
"\n",
173173
"for config in configs:\n",
174-
" B_sps = sparse.random(\n",
175-
" (config[\"I_\"], config[\"K_\"], config[\"L_\"]),\n",
176-
" density=config[\"DENSITY\"],\n",
177-
" random_state=rng,\n",
178-
" )\n",
174+
" B_shape = (config[\"I_\"], config[\"K_\"], config[\"L_\"])\n",
175+
" B_sps = sparse.random(B_shape, density=config[\"DENSITY\"], random_state=rng)\n",
179176
" D_sps = rng.random((config[\"L_\"], config[\"J_\"]))\n",
180177
" C_sps = rng.random((config[\"K_\"], config[\"J_\"]))\n",
181178
"\n",
@@ -204,14 +201,14 @@
204201
" D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
205202
" C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
206203
"\n",
207-
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
208-
" def mttkrp_finch(B, D, C):\n",
204+
" @sparse.compiled(opt=sparse.GalleyScheduler(), tag=sum(B_shape))\n",
205+
" def mttkrp_finch_galley(B, D, C):\n",
209206
" return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
210207
"\n",
211208
" # Compile\n",
212-
" result_finch_galley = mttkrp_finch(B, D, C)\n",
209+
" result_finch_galley = mttkrp_finch_galley(B, D, C)\n",
213210
" # Benchmark\n",
214-
" time_finch_galley = benchmark(mttkrp_finch, info=\"Finch\", args=[B, D, C])\n",
211+
" time_finch_galley = benchmark(mttkrp_finch_galley, info=\"Finch Galley\", args=[B, D, C])\n",
215212
"\n",
216213
" # ======= Numba =======\n",
217214
" os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
@@ -276,11 +273,12 @@
276273
"configs = [\n",
277274
" {\"LEN\": 5000, \"DENSITY\": 0.00001},\n",
278275
" {\"LEN\": 10000, \"DENSITY\": 0.00001},\n",
276+
" {\"LEN\": 15000, \"DENSITY\": 0.00001},\n",
279277
" {\"LEN\": 20000, \"DENSITY\": 0.00001},\n",
280278
" {\"LEN\": 25000, \"DENSITY\": 0.00001},\n",
281279
" {\"LEN\": 30000, \"DENSITY\": 0.00001},\n",
282280
"]\n",
283-
"size_n = [5000, 10000, 20000, 25000, 30000]\n",
281+
"size_n = [5000, 10000, 15000, 20000, 25000, 30000]\n",
284282
"\n",
285283
"if CI_MODE:\n",
286284
" configs = configs[:1]\n",
@@ -306,15 +304,12 @@
306304
" importlib.reload(sparse)\n",
307305
"\n",
308306
" s = sparse.asarray(s_sps)\n",
309-
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
310-
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
307+
" a = sparse.asarray(a_sps)\n",
308+
" b = sparse.asarray(b_sps)\n",
311309
"\n",
312310
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
313311
" def sddmm_finch(s, a, b):\n",
314-
" return sparse.sum(\n",
315-
" s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
316-
" axis=-1,\n",
317-
" )\n",
312+
" return s * (a @ b)\n",
318313
"\n",
319314
" # Compile\n",
320315
" result_finch = sddmm_finch(s, a, b)\n",
@@ -327,21 +322,17 @@
327322
" importlib.reload(sparse)\n",
328323
"\n",
329324
" s = sparse.asarray(s_sps)\n",
330-
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
331-
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
325+
" a = sparse.asarray(a_sps)\n",
326+
" b = sparse.asarray(b_sps)\n",
332327
"\n",
333-
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
334-
" def sddmm_finch(s, a, b):\n",
335-
" # return s * (a @ b)\n",
336-
" return sparse.sum(\n",
337-
" s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
338-
" axis=-1,\n",
339-
" )\n",
328+
" @sparse.compiled(opt=sparse.GalleyScheduler(), tag=LEN)\n",
329+
" def sddmm_finch_galley(s, a, b):\n",
330+
" return s * (a @ b)\n",
340331
"\n",
341332
" # Compile\n",
342-
" result_finch_galley = sddmm_finch(s, a, b)\n",
333+
" result_finch_galley = sddmm_finch_galley(s, a, b)\n",
343334
" # Benchmark\n",
344-
" time_finch_galley = benchmark(sddmm_finch, info=\"Finch\", args=[s, a, b])\n",
335+
" time_finch_galley = benchmark(sddmm_finch_galley, info=\"Finch Galley\", args=[s, a, b])\n",
345336
"\n",
346337
" # ======= Numba =======\n",
347338
" print(\"numba\")\n",
@@ -402,6 +393,13 @@
402393
"plt.show()"
403394
]
404395
},
396+
{
397+
"cell_type": "markdown",
398+
"metadata": {},
399+
"source": [
400+
"## Counting Triangles"
401+
]
402+
},
405403
{
406404
"cell_type": "code",
407405
"execution_count": null,
@@ -411,13 +409,17 @@
411409
"print(\"Counting Triangles Example:\\n\")\n",
412410
"\n",
413411
"configs = [\n",
414-
" {\"LEN\": 1000, \"DENSITY\": 0.1},\n",
415-
" {\"LEN\": 2000, \"DENSITY\": 0.1},\n",
416-
" {\"LEN\": 3000, \"DENSITY\": 0.1},\n",
417-
" {\"LEN\": 4000, \"DENSITY\": 0.1},\n",
418-
" {\"LEN\": 5000, \"DENSITY\": 0.1},\n",
412+
" {\"LEN\": 10000, \"DENSITY\": 0.001},\n",
413+
" {\"LEN\": 15000, \"DENSITY\": 0.001},\n",
414+
" {\"LEN\": 20000, \"DENSITY\": 0.001},\n",
415+
" {\"LEN\": 25000, \"DENSITY\": 0.001},\n",
416+
" {\"LEN\": 30000, \"DENSITY\": 0.001},\n",
417+
" {\"LEN\": 35000, \"DENSITY\": 0.001},\n",
418+
" {\"LEN\": 40000, \"DENSITY\": 0.001},\n",
419+
" {\"LEN\": 45000, \"DENSITY\": 0.001},\n",
420+
" {\"LEN\": 50000, \"DENSITY\": 0.001},\n",
419421
"]\n",
420-
"size_n = [1000, 2000, 3000, 4000, 5000]\n",
422+
"size_n = [10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000]\n",
421423
"\n",
422424
"if CI_MODE:\n",
423425
" configs = configs[:1]\n",
@@ -444,9 +446,7 @@
444446
"\n",
445447
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
446448
" def ct_finch(a):\n",
447-
" return sparse.sum(\n",
448-
" a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
449-
" ) / sparse.asarray(6)\n",
449+
" return sparse.sum(a @ a * a) / sparse.asarray(6)\n",
450450
"\n",
451451
" # Compile\n",
452452
" result_finch = ct_finch(a)\n",
@@ -460,16 +460,14 @@
460460
"\n",
461461
" a = sparse.asarray(a_sps)\n",
462462
"\n",
463-
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
464-
" def ct_finch(a):\n",
465-
" return sparse.sum(\n",
466-
" a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
467-
" ) / sparse.asarray(6)\n",
463+
" @sparse.compiled(opt=sparse.GalleyScheduler(), tag=LEN)\n",
464+
" def ct_finch_galley(a):\n",
465+
" return sparse.sum(a @ a * a) / sparse.asarray(6)\n",
468466
"\n",
469467
" # Compile\n",
470-
" result_finch_galley = ct_finch(a)\n",
468+
" result_finch_galley = ct_finch_galley(a)\n",
471469
" # Benchmark\n",
472-
" time_finch_galley = benchmark(ct_finch, info=\"Finch\", args=[a])\n",
470+
" time_finch_galley = benchmark(ct_finch_galley, info=\"Finch Galley\", args=[a])\n",
473471
"\n",
474472
" # ======= SciPy =======\n",
475473
" print(\"scipy\")\n",

pixi.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ precompile = "python -c 'import finch'"
4949

5050
[feature.finch.pypi-dependencies]
5151
scipy = ">=1.13"
52-
finch-tensor = ">=0.2.1"
52+
finch-tensor = ">=0.2.2"
5353

5454
[feature.finch.activation.env]
5555
SPARSE_BACKEND = "Finch"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ tests = [
5454
tox = ["sparse[tests]", "tox"]
5555
notebooks = ["sparse[tests]", "nbmake", "matplotlib"]
5656
all = ["sparse[docs,tox,notebooks]", "matrepr"]
57-
finch = ["finch-tensor>=0.2.1"]
57+
finch = ["finch-tensor>=0.2.2"]
5858

5959
[project.urls]
6060
Documentation = "https://sparse.pydata.org/"

0 commit comments

Comments
 (0)