Skip to content

Commit 5ed2868

Browse files
majosmCopilot
andauthored
Add sparse matrix interface (#349)
* add sparse matrix interface * add tentative numpy actx support * add tests * fix some typing * put actx-specific matrix multiplication implementations inside sparse_matmul instead of using derived matrix classes * remove FIXMEs about tags * add optional scipy dependency for numpy sparse * update basedpyright baseline for loopy stuff * fix make_csr_matrix docstring * omit dtype specification for floating point CSR matmul args Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * mention arraycontext[sparse] in exception message if trying to use numpy array context without scipy * don't test CSR on numpy array context if scipy isn't installed * install scipy for pytest CI * install scipy for pytest CI (gitlab) * don't use global address space for reduction bound temps --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent 8c00ed8 commit 5ed2868

File tree

10 files changed

+435
-7
lines changed

10 files changed

+435
-7
lines changed

.basedpyright/baseline.json

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,54 @@
954954
"endColumn": 29,
955955
"lineCount": 1
956956
}
957+
},
958+
{
959+
"code": "reportUnknownMemberType",
960+
"range": {
961+
"startColumn": 16,
962+
"endColumn": 28,
963+
"lineCount": 1
964+
}
965+
},
966+
{
967+
"code": "reportUnknownMemberType",
968+
"range": {
969+
"startColumn": 16,
970+
"endColumn": 28,
971+
"lineCount": 1
972+
}
973+
},
974+
{
975+
"code": "reportUnknownMemberType",
976+
"range": {
977+
"startColumn": 16,
978+
"endColumn": 28,
979+
"lineCount": 1
980+
}
981+
},
982+
{
983+
"code": "reportUnknownMemberType",
984+
"range": {
985+
"startColumn": 16,
986+
"endColumn": 28,
987+
"lineCount": 1
988+
}
989+
},
990+
{
991+
"code": "reportUnknownMemberType",
992+
"range": {
993+
"startColumn": 16,
994+
"endColumn": 28,
995+
"lineCount": 1
996+
}
997+
},
998+
{
999+
"code": "reportUnknownMemberType",
1000+
"range": {
1001+
"startColumn": 15,
1002+
"endColumn": 38,
1003+
"lineCount": 1
1004+
}
9571005
}
9581006
],
9591007
"./arraycontext/fake_numpy.py": [

.github/workflows/ci.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
- uses: actions/checkout@v6
3939
- name: "Main Script"
4040
run: |
41-
EXTRA_INSTALL="pytest types-colorama types-Pygments"
41+
EXTRA_INSTALL="pytest types-colorama types-Pygments scipy-stubs"
4242
curl -L -O https://tiker.net/ci-support-v0
4343
. ./ci-support-v0
4444
@@ -54,6 +54,7 @@ jobs:
5454
- uses: actions/checkout@v6
5555
- name: "Main Script"
5656
run: |
57+
EXTRA_INSTALL="scipy"
5758
curl -L -O https://tiker.net/ci-support-v0
5859
. ./ci-support-v0
5960
build_py_project_in_conda_env
@@ -75,6 +76,7 @@ jobs:
7576
export PYOPENCL_TEST=intel
7677
source /opt/enable-intel-cl.sh
7778
79+
EXTRA_INSTALL="scipy"
7880
curl -L -O https://tiker.net/ci-support-v0
7981
. ./ci-support-v0
8082
build_py_project_in_conda_env

.gitlab-ci.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Python 3 POCL:
22
script: |
33
export PYOPENCL_TEST=portable:cpu
4-
export EXTRA_INSTALL="jax[cpu]"
4+
export EXTRA_INSTALL="jax[cpu] scipy"
55
export JAX_PLATFORMS=cpu
66
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh
77
. ./build-and-test-py-project.sh
@@ -20,6 +20,7 @@ Python 3 Nvidia Titan V:
2020
curl -L -O https://tiker.net/ci-support-v0
2121
. ./ci-support-v0
2222
export PYOPENCL_TEST=nvi:titan
23+
export EXTRA_INSTALL="scipy"
2324
build_py_project_in_venv
2425
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
2526
test_py_project
@@ -38,6 +39,7 @@ Python 3 POCL Nvidia Titan V:
3839
curl -L -O https://tiker.net/ci-support-v0
3940
. ./ci-support-v0
4041
export PYOPENCL_TEST=port:titan
42+
export EXTRA_INSTALL="scipy"
4143
build_py_project_in_venv
4244
test_py_project
4345
@@ -66,6 +68,7 @@ Python 3 POCL Examples:
6668
Python 3 Conda:
6769
script: |
6870
export PYOPENCL_TEST=portable:cpu
71+
export EXTRA_INSTALL="scipy"
6972
7073
# Avoid crashes like https://gitlab.tiker.net/inducer/arraycontext/-/jobs/536021
7174
sed -i 's/jax/jax !=0.4.6/' .test-conda-env-py3.yml

arraycontext/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
from .context import (
7474
ArrayContext,
7575
ArrayContextFactory,
76+
CSRMatrix,
77+
SparseMatrix,
7678
tag_axes,
7779
)
7880
from .impl.jax import EagerJAXArrayContext
@@ -129,6 +131,7 @@
129131
"ArrayOrScalarT",
130132
"ArrayT",
131133
"BcastUntilActxArray",
134+
"CSRMatrix",
132135
"CommonSubexpressionTag",
133136
"ContainerOrScalarT",
134137
"EagerJAXArrayContext",
@@ -144,6 +147,7 @@
144147
"ScalarLike",
145148
"SerializationKey",
146149
"SerializedContainer",
150+
"SparseMatrix",
147151
"dataclass_array_container",
148152
"deserialize_container",
149153
"flat_size_and_dtype",

arraycontext/context.py

Lines changed: 194 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@
7878
7979
.. autoclass:: ArrayContext
8080
81+
.. autoclass:: SparseMatrix
82+
.. autoclass:: CSRMatrix
83+
8184
.. autofunction:: tag_axes
8285
8386
.. class:: P
@@ -114,13 +117,15 @@
114117
"""
115118

116119

120+
import dataclasses
117121
from abc import ABC, abstractmethod
118122
from collections.abc import Callable, Hashable, Mapping
119123
from typing import (
120124
TYPE_CHECKING,
121125
Any,
122126
ParamSpec,
123127
TypeAlias,
128+
cast,
124129
overload,
125130
)
126131
from warnings import warn
@@ -129,21 +134,27 @@
129134

130135
from pytools import memoize_method
131136

137+
from arraycontext.container.traversal import (
138+
rec_map_container,
139+
)
140+
132141

133142
if TYPE_CHECKING:
134143
import numpy as np
135144
from numpy.typing import DTypeLike
136145

137146
import loopy
138-
from pytools.tag import ToTagSetConvertible
147+
from pytools.tag import Tag, ToTagSetConvertible
139148

140149
from .fake_numpy import BaseFakeNumpyNamespace
141150
from .typing import (
142151
Array,
143152
ArrayContainerT,
144153
ArrayOrArithContainerOrScalarT,
154+
ArrayOrContainer,
145155
ArrayOrContainerOrScalar,
146156
ArrayOrContainerOrScalarT,
157+
ArrayOrScalar,
147158
ContainerOrScalarT,
148159
NumpyOrContainerOrScalar,
149160
ScalarLike,
@@ -152,6 +163,26 @@
152163

153164
P = ParamSpec("P")
154165

166+
_EMPTY_TAG_SET: frozenset[Tag] = frozenset()
167+
168+
169+
@dataclasses.dataclass(frozen=True, eq=False, repr=False)
170+
class SparseMatrix(ABC):
171+
shape: tuple[int, int]
172+
tags: ToTagSetConvertible = dataclasses.field(kw_only=True)
173+
axes: tuple[ToTagSetConvertible, ...] = dataclasses.field(kw_only=True)
174+
_actx: ArrayContext = dataclasses.field(kw_only=True)
175+
176+
def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
177+
return self._actx.sparse_matmul(self, other)
178+
179+
180+
@dataclasses.dataclass(frozen=True, eq=False, repr=False)
181+
class CSRMatrix(SparseMatrix):
182+
elem_values: Array
183+
elem_col_indices: Array
184+
row_starts: Array
185+
155186

156187
# {{{ ArrayContext
157188

@@ -169,6 +200,8 @@ class ArrayContext(ABC):
169200
.. automethod:: to_numpy
170201
.. automethod:: call_loopy
171202
.. automethod:: einsum
203+
.. automethod:: make_csr_matrix
204+
.. automethod:: sparse_matmul
172205
.. attribute:: np
173206
174207
Provides access to a namespace that serves as a work-alike to
@@ -421,6 +454,166 @@ def einsum(self,
421454
)["out"]
422455
return self.tag(tagged, out_ary)
423456

457+
def make_csr_matrix(
458+
self,
459+
shape: tuple[int, int],
460+
elem_values: Array,
461+
elem_col_indices: Array,
462+
row_starts: Array,
463+
*,
464+
tags: ToTagSetConvertible = _EMPTY_TAG_SET,
465+
axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
466+
"""Return a sparse matrix in compressed sparse row (CSR) format, to be used
467+
with :meth:`sparse_matmul`.
468+
469+
:arg shape: the (two-dimensional) shape of the matrix
470+
:arg elem_values: a one-dimensional array containing the values of all of the
471+
nonzero entries of the matrix, grouped by row.
472+
:arg elem_col_indices: a one-dimensional array containing the column index
473+
values corresponding to each entry in *elem_values*.
474+
:arg row_starts: a one-dimensional array of length ``nrows+1``, where each entry
475+
gives the starting index in *elem_values* and *elem_col_indices* for the
476+
given row, with the last entry being equal to ``len(elem_values)``.
477+
"""
478+
if axes is None:
479+
axes = (frozenset(), frozenset())
480+
481+
return CSRMatrix(
482+
shape, elem_values, elem_col_indices, row_starts,
483+
tags=tags, axes=axes,
484+
_actx=self)
485+
486+
@memoize_method
487+
def _get_csr_matmul_prg(self, out_ndim: int) -> loopy.TranslationUnit:
488+
import loopy as lp
489+
490+
out_extra_inames = tuple(f"i{n}" for n in range(1, out_ndim))
491+
out_inames = ("irow", *out_extra_inames)
492+
out_inames_set = frozenset(out_inames)
493+
494+
out_extra_shape_comp_names = tuple(f"n{n}" for n in range(1, out_ndim))
495+
out_shape_comp_names = ("nrows", *out_extra_shape_comp_names)
496+
497+
domains: list[str] = []
498+
domains.append(
499+
"{ [" + ",".join(out_inames) + "] : "
500+
+ " and ".join(
501+
f"0 <= {iname} < {shape_comp_name}"
502+
for iname, shape_comp_name in zip(
503+
out_inames, out_shape_comp_names, strict=True))
504+
+ " }")
505+
domains.append(
506+
"{ [iel] : iel_lbound <= iel < iel_ubound }")
507+
508+
temporary_variables: Mapping[str, lp.TemporaryVariable] = {
509+
"iel_lbound": lp.TemporaryVariable(
510+
"iel_lbound",
511+
shape=(),
512+
),
513+
"iel_ubound": lp.TemporaryVariable(
514+
"iel_ubound",
515+
shape=(),
516+
)}
517+
518+
from loopy.kernel.instruction import make_assignment
519+
from pymbolic import var
520+
instructions: list[lp.Assignment | lp.CallInstruction] = [
521+
make_assignment(
522+
(var("iel_lbound"),),
523+
var("row_starts")[var("irow")],
524+
id="insn0",
525+
within_inames=out_inames_set),
526+
make_assignment(
527+
(var("iel_ubound"),),
528+
var("row_starts")[var("irow") + 1],
529+
id="insn1",
530+
within_inames=out_inames_set),
531+
make_assignment(
532+
(var("out")[tuple(var(iname) for iname in out_inames)],),
533+
lp.Reduction(
534+
"sum",
535+
(var("iel"),),
536+
var("elem_values")[var("iel"),]
537+
* var("array")[(
538+
var("elem_col_indices")[var("iel"),],
539+
*(var(iname) for iname in out_extra_inames))]),
540+
id="insn2",
541+
within_inames=out_inames_set,
542+
depends_on=frozenset({"insn0", "insn1"}))]
543+
544+
from loopy.version import MOST_RECENT_LANGUAGE_VERSION
545+
546+
from .loopy import _DEFAULT_LOOPY_OPTIONS
547+
548+
knl = lp.make_kernel(
549+
domains=domains,
550+
instructions=instructions,
551+
temporary_variables=temporary_variables,
552+
kernel_data=[
553+
lp.ValueArg("nrows", is_input=True),
554+
lp.ValueArg("ncols", is_input=True),
555+
lp.ValueArg("nels", is_input=True),
556+
*(
557+
lp.ValueArg(shape_comp_name, is_input=True)
558+
for shape_comp_name in out_extra_shape_comp_names),
559+
lp.GlobalArg("elem_values", shape=(var("nels"),), is_input=True),
560+
lp.GlobalArg("elem_col_indices", shape=(var("nels"),), is_input=True),
561+
lp.GlobalArg("row_starts", shape=lp.auto, is_input=True),
562+
lp.GlobalArg(
563+
"array",
564+
shape=(
565+
var("ncols"),
566+
*(
567+
var(shape_comp_name)
568+
for shape_comp_name in out_extra_shape_comp_names),),
569+
is_input=True),
570+
lp.GlobalArg(
571+
"out",
572+
shape=(
573+
var("nrows"),
574+
*(
575+
var(shape_comp_name)
576+
for shape_comp_name in out_extra_shape_comp_names),),
577+
is_input=False),
578+
...],
579+
name="csr_matmul_kernel",
580+
lang_version=MOST_RECENT_LANGUAGE_VERSION,
581+
options=_DEFAULT_LOOPY_OPTIONS,
582+
default_order=lp.auto,
583+
default_offset=lp.auto,
584+
)
585+
586+
idx_dtype = knl.default_entrypoint.index_dtype
587+
588+
return lp.add_and_infer_dtypes(
589+
knl,
590+
{
591+
",".join([
592+
"ncols", "nrows", "nels",
593+
*out_extra_shape_comp_names]): idx_dtype,
594+
"elem_col_indices,row_starts": idx_dtype})
595+
596+
def sparse_matmul(
597+
self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer:
598+
"""Multiply a sparse matrix by an array.
599+
600+
:arg x1: the sparse matrix.
601+
:arg x2: the array.
602+
"""
603+
if isinstance(x1, CSRMatrix):
604+
def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar:
605+
assert self.is_array_type(ary)
606+
prg = self._get_csr_matmul_prg(len(ary.shape))
607+
return self.call_loopy(
608+
prg, elem_values=x1.elem_values,
609+
elem_col_indices=x1.elem_col_indices,
610+
row_starts=x1.row_starts, array=ary)["out"]
611+
612+
return cast("ArrayOrContainer", rec_map_container(_matmul, x2))
613+
614+
else:
615+
raise TypeError(f"unrecognized sparse matrix type '{type(x1).__name__}'")
616+
424617
@abstractmethod
425618
def clone(self) -> Self:
426619
"""If possible, return a version of *self* that is semantically

0 commit comments

Comments
 (0)