Skip to content

Commit 1c4d5c6

Browse files
committed
Update compiled() calls in notebook
1 parent 19a26af commit 1c4d5c6

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

examples/sparse_finch.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@
187187
" D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
188188
" C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
189189
"\n",
190-
" @sparse.compiled(opt=\"default\")\n",
190+
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
191191
" def mttkrp_finch(B, D, C):\n",
192192
" return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
193193
"\n",
@@ -204,7 +204,7 @@
204204
" D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
205205
" C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
206206
"\n",
207-
" @sparse.compiled(opt=\"galley\")\n",
207+
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
208208
" def mttkrp_finch(B, D, C):\n",
209209
" return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
210210
"\n",
@@ -309,7 +309,7 @@
309309
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
310310
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
311311
"\n",
312-
" @sparse.compiled(opt=\"default\")\n",
312+
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
313313
" def sddmm_finch(s, a, b):\n",
314314
" return sparse.sum(\n",
315315
" s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
@@ -330,7 +330,7 @@
330330
" a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
331331
" b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
332332
"\n",
333-
" @sparse.compiled(opt=\"galley\")\n",
333+
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
334334
" def sddmm_finch(s, a, b):\n",
335335
" # return s * (a @ b)\n",
336336
" return sparse.sum(\n",
@@ -442,7 +442,7 @@
442442
"\n",
443443
" a = sparse.asarray(a_sps)\n",
444444
"\n",
445-
" @sparse.compiled(opt=\"default\")\n",
445+
" @sparse.compiled(opt=sparse.DefaultScheduler())\n",
446446
" def ct_finch(a):\n",
447447
" return sparse.sum(\n",
448448
" a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
@@ -460,7 +460,7 @@
460460
"\n",
461461
" a = sparse.asarray(a_sps)\n",
462462
"\n",
463-
" @sparse.compiled(opt=\"galley\")\n",
463+
" @sparse.compiled(opt=sparse.GalleyScheduler())\n",
464464
" def ct_finch(a):\n",
465465
" return sparse.sum(\n",
466466
" a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",

sparse/finch_backend/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
raise ImportError("Finch not installed. Run `pip install sparse[finch]` to enable Finch backend") from e
55

66
from finch import (
7+
DefaultScheduler,
8+
GalleyScheduler,
79
SparseArray,
810
abs,
911
acos,
@@ -97,6 +99,7 @@
9799
remainder,
98100
reshape,
99101
round,
102+
set_optimizer,
100103
sign,
101104
sin,
102105
sinh,
@@ -119,6 +122,8 @@
119122
)
120123

121124
__all__ = [
125+
"DefaultScheduler",
126+
"GalleyScheduler",
122127
"SparseArray",
123128
"abs",
124129
"acos",
@@ -231,4 +236,5 @@
231236
"empty_like",
232237
"arange",
233238
"linspace",
239+
"set_optimizer",
234240
]

0 commit comments

Comments
 (0)