Sparse matrices #631

Merged
inducer merged 12 commits into inducer:main from majosm:sparse-matrix
Mar 13, 2026

Conversation

@majosm
Collaborator

@majosm majosm commented Jan 14, 2026

Adds CSRMatmul array type for multiplication by sparse matrices in compressed sparse row format.

cc @lukeolson
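
For readers less familiar with the format, here is a minimal NumPy sketch of compressed sparse row storage. It is not code from this PR: the helper name `dense_to_csr` and the sample matrix are hypothetical, and only the three array names are borrowed from the `make_csr_matrix` arguments that appear later in this conversation.

```python
import numpy as np


def dense_to_csr(mat):
    """Hypothetical helper: split a dense 2D array into the three CSR arrays."""
    elem_values = []
    elem_col_indices = []
    row_starts = [0]
    for row in mat:
        cols = np.nonzero(row)[0]
        elem_values.extend(row[cols])
        elem_col_indices.extend(cols)
        # row_starts[i] is the offset of row i's first stored element;
        # the final entry is the total number of stored elements.
        row_starts.append(len(elem_values))
    return (
        np.array(elem_values, dtype=np.float64),
        np.array(elem_col_indices, dtype=np.int64),
        np.array(row_starts, dtype=np.int64),
    )


# Made-up 3x4 matrix with six nonzero entries.
dense = np.array([
    [1.0, 0.0, 2.0, 0.0],
    [0.0, 0.0, 3.0, 0.0],
    [4.0, 5.0, 0.0, 6.0],
])
elem_values, elem_col_indices, row_starts = dense_to_csr(dense)
```

Row `i` owns the stored elements at positions `row_starts[i]` up to (but not including) `row_starts[i + 1]`, which is why `row_starts` has one more entry than the matrix has rows.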

@majosm
Collaborator Author

majosm commented Jan 14, 2026

@inducer I started looking through the loopy codegen code to try to figure out how to generate temporaries for the reduction bounds to avoid islpy errors (as is done in the loopy example), but I'm having some trouble figuring out how best to do that. Seems like I need to create the temporaries inside InlinedExpressionGenMapper.map_reduce() before adding the reduction domain to the kernel, but I'm not sure how to make sure the temporaries get the right scope, dependencies, inames, etc. Any suggestions on how I could approach this? Should I maybe be doing it inside CodeGenMapper.map_index_lambda by collecting the reduction info ahead of time and creating temporaries before calling the exprgen_mapper?

@inducer
Owner

inducer commented Jan 14, 2026

Yes, working inside of CodeGenMapper.map_index_lambda seems like a possibly promising route. I'm not sure what info you would need to collect ahead of time though.

@majosm majosm force-pushed the sparse-matrix branch 5 times, most recently from de98230 to 2e1c083 Compare January 22, 2026 21:41
@majosm
Collaborator Author

majosm commented Jan 22, 2026

@inducer I think this is ready for a look. I'm not sure how to fix the lingering doc build errors.

Here's what the results are looking like. Code:

import pytato as pt

n = 20  # inferred from the argument shapes in the kernel printout below

A_pl = pt.make_csr_matrix(  # noqa: N806
    shape=(n-2, n),
    elem_values=pt.make_placeholder("A_elem_values", (3*(n-2),)),
    elem_col_indices=pt.make_placeholder("A_elem_col_indices", (3*(n-2),)),
    row_starts=pt.make_placeholder("A_row_starts", (n-1,)))
u_pl = pt.make_placeholder("u", n)
prog = pt.generate_loopy(A_pl @ u_pl)

Kernel:

---------------------------------------------------------------------------
KERNEL: _pt_kernel
---------------------------------------------------------------------------
ARGUMENTS:
A_elem_values: type: np:dtype('float64'), shape: (54), dim_tags: (N0:stride:1), offset: <class 'loopy.typing.auto'> in aspace: global
A_elem_col_indices: type: np:dtype('float64'), shape: (54), dim_tags: (N0:stride:1), offset: <class 'loopy.typing.auto'> in aspace: global
A_row_starts: type: np:dtype('float64'), shape: (19), dim_tags: (N0:stride:1), offset: <class 'loopy.typing.auto'> in aspace: global
u: type: np:dtype('float64'), shape: (20), dim_tags: (N0:stride:1), offset: <class 'loopy.typing.auto'> in aspace: global
_pt_out: type: np:dtype('float64'), shape: (18), dim_tags: (N0:stride:1) out aspace: global
---------------------------------------------------------------------------
DOMAINS:
{  :  }
{ [_pt_temp_dim0] : 0 <= _pt_temp_dim0 <= 17 }
  [_pt_sum_r0_lbound, _pt_sum_r0_ubound] -> { [_pt_sum_r0] : _pt_sum_r0_lbound <= _pt_sum_r0 < _pt_sum_r0_ubound }
{ [_pt_out_dim0] : 0 <= _pt_out_dim0 <= 17 }
---------------------------------------------------------------------------
INAME TAGS:
_pt_out_dim0: None
_pt_sum_r0: None
_pt_temp_dim0: None
---------------------------------------------------------------------------
TEMPORARIES:
_pt_sum_r0_lbound: type: np:dtype('int64'), shape: () aspace: global
_pt_sum_r0_ubound: type: np:dtype('int64'), shape: () aspace: global
_pt_temp: type: np:dtype('float64'), shape: (18), dim_tags: (N0:stride:1) aspace: global
---------------------------------------------------------------------------
INSTRUCTIONS:
    for _pt_temp_dim0
↱     _pt_sum_r0_lbound = A_row_starts[_pt_temp_dim0]  {id=_pt_sum_r0_lbound_store}
│↱    _pt_sum_r0_ubound = A_row_starts[_pt_temp_dim0 + 1]  {id=_pt_sum_r0_ubound_store}
└└↱   _pt_temp[_pt_temp_dim0] = reduce(sum, [_pt_sum_r0], A_elem_values[_pt_sum_r0]*u[A_elem_col_indices[_pt_sum_r0]])  {id=_pt_temp_store}
  │ end _pt_temp_dim0
  │ for _pt_out_dim0
  └   _pt_out[_pt_out_dim0] = _pt_temp[_pt_out_dim0]  {id=_pt_out_store}
    end _pt_out_dim0
---------------------------------------------------------------------------
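
The instruction stream above can be read as the plain-NumPy loop below. This is an illustrative sketch, not code from the PR: `csr_matvec_reference` is a hypothetical name, `n = 20` is inferred from the argument shapes in the printout, and the stencil values are made up. The index arrays use an integer dtype here so that NumPy can subscript with them.

```python
import numpy as np


def csr_matvec_reference(elem_values, elem_col_indices, row_starts, u):
    """Hypothetical reference loop mirroring the kernel's instructions."""
    nrows = len(row_starts) - 1
    out = np.zeros(nrows, dtype=np.result_type(elem_values, u))
    for i in range(nrows):
        # Data-dependent reduction bounds, read once per row
        # (the _pt_sum_r0_lbound / _pt_sum_r0_ubound temporaries).
        lbound = row_starts[i]
        ubound = row_starts[i + 1]
        acc = 0.0
        for r in range(lbound, ubound):
            acc += elem_values[r] * u[elem_col_indices[r]]
        out[i] = acc
    return out


n = 20  # assumed; consistent with shapes (54), (19), (20), (18) above
nrows = n - 2

# Made-up data: a three-point stencil [1, -2, 1] on every row.
elem_values = np.tile(np.array([1.0, -2.0, 1.0]), nrows)
elem_col_indices = (np.arange(nrows)[:, None] + np.arange(3)).ravel()
row_starts = 3 * np.arange(nrows + 1)

u = np.linspace(0.0, 1.0, n) ** 2
result = csr_matvec_reference(elem_values, elem_col_indices, row_starts, u)

# Cross-check against the equivalent dense product.
dense = np.zeros((nrows, n))
for i in range(nrows):
    dense[i, i:i + 3] = [1.0, -2.0, 1.0]
assert np.allclose(result, dense @ u)
```

The inner loop's range depends on values loaded from `row_starts`, which is why the reduction domain in the printout is parametrized by the two bound temporaries rather than by constants.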

@majosm majosm requested a review from inducer January 22, 2026 22:01
@majosm majosm marked this pull request as ready for review January 22, 2026 22:19
@majosm majosm force-pushed the sparse-matrix branch 2 times, most recently from 5159091 to f7f8699 Compare January 23, 2026 18:54
@majosm majosm requested a review from inducer January 23, 2026 21:59
@majosm
Collaborator Author

majosm commented Feb 5, 2026

FWIW, while experimenting with containers in inducer/arraycontext#349, I tried seeing what this PR would look like if sparse matrix objects were treated as arrays. The changes are here. (Basically, matrices are treated as any other array except in codegen-related mappers, which refuse to process matrices that aren't reached through matmul arrays.)

@majosm majosm force-pushed the sparse-matrix branch 2 times, most recently from faa9bfc to 40eb010 Compare February 16, 2026 16:42
@majosm majosm force-pushed the sparse-matrix branch 3 times, most recently from 4675d86 to 2274478 Compare February 20, 2026 19:51
@majosm majosm requested a review from inducer February 20, 2026 21:29
@majosm majosm force-pushed the sparse-matrix branch 2 times, most recently from 87f5820 to 2fff24d Compare February 25, 2026 18:53
@inducer inducer requested a review from Copilot March 10, 2026 19:29
Contributor

Copilot AI left a comment

Pull request overview

Adds first-class sparse matrix support (CSR) to Pytato by introducing new array/node types and plumbing them through transforms, visualization, equality/analysis, and code generation.

Changes:

  • Introduces SparseMatrix/CSRMatrix and SparseMatmul/CSRMatmul, plus user-facing constructors (make_csr_matrix) and operation (sparse_matmul / @).
  • Extends core mappers/transforms/visualizers/equality/analysis to recognize and traverse CSRMatmul.
  • Updates Loopy codegen to handle more general reduction bounds (incl. non-affine detection/materialization) and adds CSR matmul tests.
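
To illustrate what "non-affine" means for a reduction bound, here is a toy classifier over nested-tuple expressions. It is a sketch under stated assumptions and not pytato's `is_quasi_affine`: the function name `is_quasi_affine_toy`, the tuple encoding, and the `"subscript"` tag are all hypothetical. It uses the usual notion of quasi-affine: sums, products with integer constants, and floor division or modulo by integer constants.

```python
def is_quasi_affine_toy(expr):
    """Toy check on nested tuples; NOT pytato's is_quasi_affine."""
    if isinstance(expr, int):
        return True
    if isinstance(expr, str):  # a variable or parameter name
        return True
    op, *args = expr
    if op == "+":
        return all(is_quasi_affine_toy(a) for a in args)
    if op == "*":
        # Conservative: only literal ints count as constants, and at most
        # one factor may be non-constant.
        non_const = [a for a in args if not isinstance(a, int)]
        return len(non_const) <= 1 and all(
            is_quasi_affine_toy(a) for a in non_const)
    if op in ("//", "%"):
        num, den = args
        return isinstance(den, int) and den > 0 and is_quasi_affine_toy(num)
    # Anything else (e.g. an array subscript) is treated as data-dependent.
    return False


affine_bound = ("+", ("*", 2, "i"), 1)           # 2*i + 1
csr_bound = ("subscript", "A_row_starts", "i")   # A_row_starts[i]
```

Under this toy definition `affine_bound` passes and `csr_bound` does not. A bound of the second kind is what the kernel shown earlier in the conversation stores in a scalar temporary and then uses as a parameter of the reduction domain.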

Reviewed changes

Copilot reviewed 21 out of 21 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
test/test_pytato.py Adds input-validation tests for CSR matmul.
test/test_codegen.py Adds end-to-end Loopy codegen test for CSR matmul (single/multiple/stacked).
pytato/visualization/fancy_placeholder_data_flow.py Adds CSR matmul nodes to the “fancy” data-flow visualization.
pytato/visualization/dot.py Adds CSR matmul nodes/edges to DOT visualization output.
pytato/transform/metadata.py Ensures metadata collection traverses CSR matmul operands.
pytato/transform/materialize.py Enables materialization logic to rebuild/visit CSR matmul children without duplication.
pytato/transform/lower_to_index_lambda.py Lowers CSRMatmul to an IndexLambda using a symbolic reduction.
pytato/transform/einsum_distributive_law.py Integrates CSRMatmul into distributive-law transformation flow.
pytato/transform/init.py Adds CSR matmul handling to core transform mapper base classes/walkers.
pytato/target/python/numpy_like.py Explicitly marks CSR matmul unsupported for numpy-like targets.
pytato/target/loopy/codegen.py Substantial changes to reduction handling and storage to support new lowering paths.
pytato/stringifier.py Allows generic stringification of CSR matmul nodes.
pytato/scalar_expr.py Adds is_quasi_affine helper for reduction-bound analysis.
pytato/equality.py Adds structural equality support for CSR matmul nodes.
pytato/codegen.py Documents CSR matmul lowering in CodeGenPreprocessor.
pytato/array.py Adds sparse matrix types, constructors, and sparse_matmul implementation.
pytato/analysis/init.py Adds CSR matmul support to dependency/user analysis utilities.
pytato/init.py Exposes new sparse APIs/types at the package top level.
doc/conf.py Tweaks Sphinx nitpicks/missing-reference aliases for new typing/docs references.
.test-conda-env-py3.yml Adds matplotlib-base dependency (used optionally by visualization in tests).
.basedpyright/baseline.json Updates pyright baseline for new/changed typing diagnostics.

Comment on lines +154 to +155
# FIXME: This mapper still needs to be updated to avoid duplicating arrays (see
# https://github.com/inducer/pytato/pull/515).
Owner

I think #644 took care of that, removing the FIXME.

@inducer inducer enabled auto-merge (squash) March 13, 2026 22:12
@inducer inducer merged commit e2a3111 into inducer:main Mar 13, 2026
10 checks passed