Skip to content

Commit e2a3111

Browse files
majosminducerCopilot
authored
Sparse matrices (#631)
* add CSR sparse matrix multiplication array * add matplotlib to test conda env * add some checks for shapes of things in make_csr_matrix, and move some other checks over from constructor * minor wording change in comment * accept tuple[ToTagSetConvertible, ...] for axes in make_csr_matrix * check for non-affineness instead of subscripts * Fix doc build * Fix row_starts docs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Matthew Smith <mjsmith6@illinois.edu> * fix placeholder names in tests * More structural validation in make_csr_matrix * Remove einsum distribute-law FIXME x-ref: gh-644 * Rename some variables in codegen --------- Co-authored-by: Andreas Kloeckner <inform@tiker.net> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b25c311 commit e2a3111

21 files changed

+1004
-232
lines changed

.basedpyright/baseline.json

Lines changed: 64 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -2547,6 +2547,22 @@
25472547
"lineCount": 1
25482548
}
25492549
},
2550+
{
2551+
"code": "reportIncompatibleVariableOverride",
2552+
"range": {
2553+
"startColumn": 6,
2554+
"endColumn": 18,
2555+
"lineCount": 1
2556+
}
2557+
},
2558+
{
2559+
"code": "reportIncompatibleVariableOverride",
2560+
"range": {
2561+
"startColumn": 6,
2562+
"endColumn": 18,
2563+
"lineCount": 1
2564+
}
2565+
},
25502566
{
25512567
"code": "reportConstantRedefinition",
25522568
"range": {
@@ -2627,6 +2643,14 @@
26272643
"lineCount": 1
26282644
}
26292645
},
2646+
{
2647+
"code": "reportCallInDefaultInitializer",
2648+
"range": {
2649+
"startColumn": 43,
2650+
"endColumn": 54,
2651+
"lineCount": 1
2652+
}
2653+
},
26302654
{
26312655
"code": "reportAny",
26322656
"range": {
@@ -5617,6 +5641,14 @@
56175641
"lineCount": 1
56185642
}
56195643
},
5644+
{
5645+
"code": "reportUnannotatedClassAttribute",
5646+
"range": {
5647+
"startColumn": 4,
5648+
"endColumn": 18,
5649+
"lineCount": 1
5650+
}
5651+
},
56205652
{
56215653
"code": "reportUnannotatedClassAttribute",
56225654
"range": {
@@ -7041,134 +7073,6 @@
70417073
"lineCount": 1
70427074
}
70437075
},
7044-
{
7045-
"code": "reportImplicitOverride",
7046-
"range": {
7047-
"startColumn": 8,
7048-
"endColumn": 24,
7049-
"lineCount": 1
7050-
}
7051-
},
7052-
{
7053-
"code": "reportImplicitOverride",
7054-
"range": {
7055-
"startColumn": 8,
7056-
"endColumn": 23,
7057-
"lineCount": 1
7058-
}
7059-
},
7060-
{
7061-
"code": "reportImplicitOverride",
7062-
"range": {
7063-
"startColumn": 8,
7064-
"endColumn": 24,
7065-
"lineCount": 1
7066-
}
7067-
},
7068-
{
7069-
"code": "reportImplicitOverride",
7070-
"range": {
7071-
"startColumn": 8,
7072-
"endColumn": 17,
7073-
"lineCount": 1
7074-
}
7075-
},
7076-
{
7077-
"code": "reportImplicitOverride",
7078-
"range": {
7079-
"startColumn": 8,
7080-
"endColumn": 16,
7081-
"lineCount": 1
7082-
}
7083-
},
7084-
{
7085-
"code": "reportImplicitOverride",
7086-
"range": {
7087-
"startColumn": 8,
7088-
"endColumn": 28,
7089-
"lineCount": 1
7090-
}
7091-
},
7092-
{
7093-
"code": "reportImplicitOverride",
7094-
"range": {
7095-
"startColumn": 8,
7096-
"endColumn": 23,
7097-
"lineCount": 1
7098-
}
7099-
},
7100-
{
7101-
"code": "reportImplicitOverride",
7102-
"range": {
7103-
"startColumn": 8,
7104-
"endColumn": 19,
7105-
"lineCount": 1
7106-
}
7107-
},
7108-
{
7109-
"code": "reportImplicitOverride",
7110-
"range": {
7111-
"startColumn": 8,
7112-
"endColumn": 23,
7113-
"lineCount": 1
7114-
}
7115-
},
7116-
{
7117-
"code": "reportImplicitOverride",
7118-
"range": {
7119-
"startColumn": 8,
7120-
"endColumn": 18,
7121-
"lineCount": 1
7122-
}
7123-
},
7124-
{
7125-
"code": "reportImplicitOverride",
7126-
"range": {
7127-
"startColumn": 8,
7128-
"endColumn": 23,
7129-
"lineCount": 1
7130-
}
7131-
},
7132-
{
7133-
"code": "reportImplicitOverride",
7134-
"range": {
7135-
"startColumn": 8,
7136-
"endColumn": 29,
7137-
"lineCount": 1
7138-
}
7139-
},
7140-
{
7141-
"code": "reportImplicitOverride",
7142-
"range": {
7143-
"startColumn": 8,
7144-
"endColumn": 39,
7145-
"lineCount": 1
7146-
}
7147-
},
7148-
{
7149-
"code": "reportImplicitOverride",
7150-
"range": {
7151-
"startColumn": 8,
7152-
"endColumn": 28,
7153-
"lineCount": 1
7154-
}
7155-
},
7156-
{
7157-
"code": "reportImplicitOverride",
7158-
"range": {
7159-
"startColumn": 8,
7160-
"endColumn": 16,
7161-
"lineCount": 1
7162-
}
7163-
},
7164-
{
7165-
"code": "reportImplicitOverride",
7166-
"range": {
7167-
"startColumn": 8,
7168-
"endColumn": 29,
7169-
"lineCount": 1
7170-
}
7171-
},
71727076
{
71737077
"code": "reportPrivateUsage",
71747078
"range": {
@@ -7177,14 +7081,6 @@
71777081
"lineCount": 1
71787082
}
71797083
},
7180-
{
7181-
"code": "reportImplicitOverride",
7182-
"range": {
7183-
"startColumn": 8,
7184-
"endColumn": 24,
7185-
"lineCount": 1
7186-
}
7187-
},
71887084
{
71897085
"code": "reportUnannotatedClassAttribute",
71907086
"range": {
@@ -10741,6 +10637,38 @@
1074110637
"lineCount": 1
1074210638
}
1074310639
},
10640+
{
10641+
"code": "reportUnusedExpression",
10642+
"range": {
10643+
"startColumn": 4,
10644+
"endColumn": 9,
10645+
"lineCount": 1
10646+
}
10647+
},
10648+
{
10649+
"code": "reportUnusedExpression",
10650+
"range": {
10651+
"startColumn": 8,
10652+
"endColumn": 13,
10653+
"lineCount": 1
10654+
}
10655+
},
10656+
{
10657+
"code": "reportUnusedExpression",
10658+
"range": {
10659+
"startColumn": 8,
10660+
"endColumn": 13,
10661+
"lineCount": 1
10662+
}
10663+
},
10664+
{
10665+
"code": "reportUnusedExpression",
10666+
"range": {
10667+
"startColumn": 4,
10668+
"endColumn": 9,
10669+
"lineCount": 1
10670+
}
10671+
},
1074410672
{
1074510673
"code": "reportUnknownMemberType",
1074610674
"range": {

.test-conda-env-py3.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ dependencies:
1515
- jax
1616
- openmpi # Force using Open MPI since our pytest infrastructure needs it
1717
- graphviz # for visualization tests
18+
- matplotlib-base

doc/conf.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,17 @@
6161

6262
# It's :data:, not :class:, but we can't tell autodoc that.
6363
["py:class", r"types\.EllipsisType"],
64+
# pytools
65+
# Got documented in Feb 2026, try removing?
66+
["py:class", "ToTagSetConvertible"],
6467
]
6568

6669

6770
sphinxconfig_missing_reference_aliases = {
6871
# pymbolic
6972
"ArithmeticExpression": "obj:pymbolic.ArithmeticExpression",
73+
# pytools
74+
"lp.TemporaryVariable": "class:loopy.TemporaryVariable",
7075
}
7176

7277

pytato/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def set_debug_enabled(flag: bool) -> None:
5757
AxisPermutation,
5858
BasicIndex,
5959
Concatenate,
60+
CSRMatmul,
61+
CSRMatrix,
6062
DataWrapper,
6163
DictOfNamedArrays,
6264
Einsum,
@@ -70,6 +72,8 @@ def set_debug_enabled(flag: bool) -> None:
7072
Reshape,
7173
Roll,
7274
SizeParam,
75+
SparseMatmul,
76+
SparseMatrix,
7377
Stack,
7478
arange,
7579
broadcast_to,
@@ -87,6 +91,7 @@ def set_debug_enabled(flag: bool) -> None:
8791
logical_and,
8892
logical_not,
8993
logical_or,
94+
make_csr_matrix,
9095
make_data_wrapper,
9196
make_dict_of_named_arrays,
9297
make_placeholder,
@@ -99,6 +104,7 @@ def set_debug_enabled(flag: bool) -> None:
99104
reshape,
100105
roll,
101106
set_traceback_tag_enabled,
107+
sparse_matmul,
102108
squeeze,
103109
stack,
104110
transpose,
@@ -179,6 +185,8 @@ def set_debug_enabled(flag: bool) -> None:
179185
"Axis",
180186
"AxisPermutation",
181187
"BasicIndex",
188+
"CSRMatmul",
189+
"CSRMatrix",
182190
"Concatenate",
183191
"DataWrapper",
184192
"DictOfNamedArrays",
@@ -200,6 +208,8 @@ def set_debug_enabled(flag: bool) -> None:
200208
"Reshape",
201209
"Roll",
202210
"SizeParam",
211+
"SparseMatmul",
212+
"SparseMatrix",
203213
"Stack",
204214
"Target",
205215
"abs",
@@ -247,6 +257,7 @@ def set_debug_enabled(flag: bool) -> None:
247257
"logical_and",
248258
"logical_not",
249259
"logical_or",
260+
"make_csr_matrix",
250261
"make_data_wrapper",
251262
"make_dict_of_named_arrays",
252263
"make_distributed_recv",
@@ -273,6 +284,7 @@ def set_debug_enabled(flag: bool) -> None:
273284
"show_fancy_placeholder_data_flow",
274285
"sin",
275286
"sinh",
287+
"sparse_matmul",
276288
"sqrt",
277289
"squeeze",
278290
"stack",

pytato/analysis/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from pytato.array import (
3939
Array,
4040
Concatenate,
41+
CSRMatmul,
4142
DictOfNamedArrays,
4243
Einsum,
4344
IndexBase,
@@ -155,6 +156,20 @@ def map_einsum(self, expr: Einsum) -> None:
155156
self.array_to_users[dim].append(expr)
156157
self.rec(dim)
157158

159+
def map_csr_matmul(self, expr: CSRMatmul) -> None:
160+
for ary in (
161+
expr.matrix.elem_values,
162+
expr.matrix.elem_col_indices,
163+
expr.matrix.row_starts,
164+
expr.array):
165+
self.array_to_users[ary].append(expr)
166+
self.rec(ary)
167+
168+
for dim in expr.shape:
169+
if isinstance(dim, Array):
170+
self.array_to_users[dim].append(expr)
171+
self.rec(dim)
172+
158173
def map_named_array(self, expr: NamedArray) -> None:
159174
self.rec(expr._container)
160175

@@ -378,6 +393,14 @@ def map_concatenate(self, expr: Concatenate) -> list[ArrayOrNames]:
378393
def map_einsum(self, expr: Einsum) -> list[ArrayOrNames]:
379394
return self._get_preds_from_shape(expr.shape) + list(expr.args)
380395

396+
def map_csr_matmul(self, expr: CSRMatmul) -> list[ArrayOrNames]:
397+
return [
398+
*self._get_preds_from_shape(expr.shape),
399+
expr.matrix.elem_values,
400+
expr.matrix.elem_col_indices,
401+
expr.matrix.row_starts,
402+
expr.array]
403+
381404
def map_loopy_call(self, expr: LoopyCall) -> list[ArrayOrNames]:
382405
return [ary for ary in expr.bindings.values() if isinstance(ary, Array)]
383406

0 commit comments

Comments
 (0)