|
187 | 187 | " D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
|
188 | 188 | " C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
|
189 | 189 | "\n",
|
190 |
| - " @sparse.compiled(opt=\"default\")\n", |
| 190 | + " @sparse.compiled(opt=sparse.DefaultScheduler())\n", |
191 | 191 | " def mttkrp_finch(B, D, C):\n",
|
192 | 192 | " return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
|
193 | 193 | "\n",
|
|
204 | 204 | " D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
|
205 | 205 | " C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
|
206 | 206 | "\n",
|
207 |
| - " @sparse.compiled(opt=\"galley\")\n", |
| 207 | + " @sparse.compiled(opt=sparse.GalleyScheduler())\n", |
208 | 208 | " def mttkrp_finch(B, D, C):\n",
|
209 | 209 | " return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
|
210 | 210 | "\n",
|
|
309 | 309 | " a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
|
310 | 310 | " b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
|
311 | 311 | "\n",
|
312 |
| - " @sparse.compiled(opt=\"default\")\n", |
| 312 | + " @sparse.compiled(opt=sparse.DefaultScheduler())\n", |
313 | 313 | " def sddmm_finch(s, a, b):\n",
|
314 | 314 | " return sparse.sum(\n",
|
315 | 315 | " s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),\n",
|
|
330 | 330 | " a = sparse.asarray(np.array(a_sps, order=\"F\"))\n",
|
331 | 331 | " b = sparse.asarray(np.array(b_sps, order=\"C\"))\n",
|
332 | 332 | "\n",
|
333 |
| - " @sparse.compiled(opt=\"galley\")\n", |
| 333 | + " @sparse.compiled(opt=sparse.GalleyScheduler())\n", |
334 | 334 | " def sddmm_finch(s, a, b):\n",
|
335 | 335 | " # return s * (a @ b)\n",
|
336 | 336 | " return sparse.sum(\n",
|
|
442 | 442 | "\n",
|
443 | 443 | " a = sparse.asarray(a_sps)\n",
|
444 | 444 | "\n",
|
445 |
| - " @sparse.compiled(opt=\"default\")\n", |
| 445 | + " @sparse.compiled(opt=sparse.DefaultScheduler())\n", |
446 | 446 | " def ct_finch(a):\n",
|
447 | 447 | " return sparse.sum(\n",
|
448 | 448 | " a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
|
|
460 | 460 | "\n",
|
461 | 461 | " a = sparse.asarray(a_sps)\n",
|
462 | 462 | "\n",
|
463 |
| - " @sparse.compiled(opt=\"galley\")\n", |
| 463 | + " @sparse.compiled(opt=sparse.GalleyScheduler())\n", |
464 | 464 | " def ct_finch(a):\n",
|
465 | 465 | " return sparse.sum(\n",
|
466 | 466 | " a[:, :, None] * a[:, None, :] * sparse.permute_dims(a, (1, 0))[None, :, :],\n",
|
|
0 commit comments