4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
changes that do not affect the user.

## [Unreleased]
## [0.7.1] - 2025-06-12
### Added
- Seamless sparse-matrix support (SpMM and adjacency handling) for TorchJD: the stock sparse matrix multiplication is currently not compatible with Jacobian Descent's torch.vmap() dependencies.


## [0.7.0] - 2025-06-04

6 changes: 6 additions & 0 deletions docs/source/docs/sparse.rst
@@ -0,0 +1,6 @@
:hide-toc:

sparse.sparse_mm
================

.. autofunction:: torchjd.sparse.sparse_mm
1 change: 1 addition & 0 deletions docs/source/examples/index.rst
@@ -28,6 +28,7 @@ This section contains some usage examples for TorchJD.
basic_usage.rst
iwrm.rst
mtl.rst
sparse.rst
rnn.rst
monitoring.rst
lightning_integration.rst
34 changes: 34 additions & 0 deletions docs/source/examples/sparse.rst
@@ -0,0 +1,34 @@
Quick example
==============================

TorchJD now offers helpers that make working with sparse adjacency matrices
transparent.
The key entry point is :func:`torchjd.sparse.sparse_mm`,
a vmap-aware autograd function that replaces the usual
``torch.sparse.mm`` inside Jacobian Descent pipelines.

The snippet below shows how to mix a sparse objective (involving
``A @ p``) with a dense one, then aggregate their Jacobians with
:class:`torchjd.aggregation.UPGrad`.

.. doctest::

>>> import torch
>>> from torchjd import backward
>>> from torchjd.sparse import sparse_mm # patches torch automatically
>>> from torchjd.aggregation import UPGrad
>>>
>>> # 2×2 off-diagonal adjacency matrix
>>> A = torch.sparse_coo_tensor(
... indices=[[0, 1], [1, 0]],
... values=[1.0, 1.0],
... size=(2, 2)
... ).coalesce()
>>>
>>> p = torch.tensor([1.0, 2.0], requires_grad=True)
>>>
>>> y1 = sparse_mm(A, p.unsqueeze(1)).sum() # sparse term
>>> y2 = (p ** 2).sum() # dense term
>>> backward([y1, y2], UPGrad()) # Jacobian Descent step
>>> p.grad # doctest:+ELLIPSIS
tensor([1.0000, 1.6667])
7 changes: 0 additions & 7 deletions src/torchjd/__init__.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/torchjd/_autojac/__init__.py
@@ -1,2 +1,4 @@
from torchjd.sparse import sparse_mm

from ._backward import backward
from ._mtl_backward import mtl_backward
19 changes: 19 additions & 0 deletions src/torchjd/sparse/__init__.py
@@ -0,0 +1,19 @@
"""Public interface for TorchJD sparse helpers.

Importing ``torchjd`` automatically activates seamless sparse support,
unless the environment variable ``TORCHJD_DISABLE_SPARSE`` is set to
``"1"`` **before** the first TorchJD import.
"""

from __future__ import annotations

import os

from ._autograd import sparse_mm # re-export
from ._patch import enable_seamless_sparse

__all__ = ["sparse_mm"]

# feature flag
if os.getenv("TORCHJD_DISABLE_SPARSE", "0") != "1":
enable_seamless_sparse()
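
For reference, a minimal sketch of how the ``TORCHJD_DISABLE_SPARSE`` flag described in the docstring above would be used; the snippet is illustrative and assumes the variable is set before the first TorchJD import:

>>> import os
>>> os.environ["TORCHJD_DISABLE_SPARSE"] = "1"  # must be set before importing torchjd
>>> import torchjd                              # sparse patching is skipped
>>> # torch.sparse.mm and Tensor.__matmul__ are left untouched in this configuration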
62 changes: 62 additions & 0 deletions src/torchjd/sparse/_autograd.py
@@ -0,0 +1,62 @@
"""Vmap-compatible sparse @ dense for TorchJD."""

from __future__ import annotations

from typing import Tuple

import torch

from ._registry import to_coalesced_coo

_orig_sparse_mm = getattr(torch.sparse, "_orig_mm", torch.sparse.mm)


class _SparseMatMul(torch.autograd.Function):
"""y = A @ X where **A** is sparse and **X** is dense."""

@staticmethod
def forward(A_like: torch.Tensor, X: torch.Tensor) -> torch.Tensor: # noqa: D401
A = to_coalesced_coo(A_like)

if X.dim() == 3: # (B, N, d)
B, N, d = X.shape
            X2d = X.permute(1, 0, 2).reshape(N, B * d)  # stack batches along columns
Y2d = _orig_sparse_mm(A, X2d) # pragma: no cover
return Y2d.view(N, B, d).permute(1, 0, 2) # pragma: no cover

return _orig_sparse_mm(A, X)

@staticmethod
def setup_context(ctx, inputs, output) -> None: # noqa: D401
A_like, _ = inputs
ctx.save_for_backward(to_coalesced_coo(A_like))

@staticmethod
def backward(ctx, dY: torch.Tensor) -> Tuple[None, torch.Tensor]:
(A,) = ctx.saved_tensors
AT = A.transpose(0, 1)

if dY.dim() == 3: # batched
B, N, d = dY.shape
dY2d = dY.permute(1, 0, 2).reshape(N, B * d)
dX2d = _orig_sparse_mm(AT, dY2d)
dX = dX2d.view(N, B, d).permute(1, 0, 2)
return None, dX

return None, _orig_sparse_mm(AT, dY) # pragma: no cover

@staticmethod
def vmap(info, in_dims, A_unbatched, X_batched): # noqa: D401
A = A_unbatched # shared
X = X_batched # (B, N, d)

B, N, d = X.shape
        X2d = X.permute(1, 0, 2).reshape(N, B * d)  # stack batches along columns
Y2d = _orig_sparse_mm(A, X2d)
Y = Y2d.view(N, B, d).permute(1, 0, 2)
return Y, 0 # output & out-dims


def sparse_mm(A_like: torch.Tensor, X: torch.Tensor) -> torch.Tensor:
"""Return ``A @ X`` through the vmap-safe sparse Function."""
return _SparseMatMul.apply(A_like, X)
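
A minimal usage sketch of the Function above, exercising the batched branch of ``forward`` (shapes and tensors are illustrative, not taken from the test suite):

>>> import torch
>>> from torchjd.sparse import sparse_mm
>>>
>>> A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0], (2, 2)).coalesce()
>>> X = torch.randn(4, 2, 3, requires_grad=True)  # (B, N, d) batched dense input
>>> Y = sparse_mm(A, X)                           # (B, N, d): Y[b] == A @ X[b]
>>> Y.sum().backward()                            # gradients flow through the custom backward
>>> X.grad.shape
torch.Size([4, 2, 3])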
81 changes: 81 additions & 0 deletions src/torchjd/sparse/_patch.py
@@ -0,0 +1,81 @@
"""Monkey-patch hooks that route sparse ops through TorchJD wrappers.

This module is imported from ``torchjd.sparse`` at import time.
Patching is *idempotent*: calling :func:`enable_seamless_sparse`
multiple times is safe.
"""

from __future__ import annotations

import warnings
from importlib import import_module
from typing import Callable

import torch

from ._autograd import sparse_mm

# The ``torch_sparse`` wheel might be installed yet ABI-incompatible with
# the current PyTorch build, which raises *OSError* at import time.

try: # pragma: no cover
torch_sparse = import_module("torch_sparse") # type: ignore
except (ModuleNotFoundError, OSError):
torch_sparse = None


# Helpers
def _wrap_mm(orig_fn: Callable, wrapper: Callable) -> Callable:
"""Return a patched ``torch.sparse.mm`` that defers to *wrapper*."""

def patched(A, X): # noqa: D401
if isinstance(A, torch.Tensor) and A.is_sparse and X.dim() >= 2:
return wrapper(A, X)
return orig_fn(A, X)

return patched


def _wrap_tensor_matmul(orig_fn: Callable) -> Callable:
def patched(self, other): # noqa: D401
if self.is_sparse and isinstance(other, torch.Tensor) and other.dim() >= 2:
return sparse_mm(self, other)
return orig_fn(self, other)

return patched


# Public API
def enable_seamless_sparse() -> None:
"""Patch common call-sites so users need *no* explicit imports."""
# torch.sparse.mm
if getattr(torch.sparse, "_orig_mm", None) is None:
torch.sparse._orig_mm = torch.sparse.mm # type: ignore[attr-defined]
torch.sparse.mm = _wrap_mm(torch.sparse._orig_mm, sparse_mm) # type: ignore[attr-defined]

# tensor @ dense
if getattr(torch.Tensor, "_orig_matmul", None) is None:
torch.Tensor._orig_matmul = torch.Tensor.__matmul__ # type: ignore[attr-defined] # noqa: E501
torch.Tensor.__matmul__ = _wrap_tensor_matmul(
torch.Tensor._orig_matmul # type: ignore[attr-defined]
) # type: ignore[attr-defined]

# torch_sparse (optional)
if torch_sparse is None:
warnings.warn(
"torch_sparse not found: SpSpMM will use slow fallback.",
RuntimeWarning,
stacklevel=2,
) # pragma: no cover
return

if not hasattr(torch_sparse.SparseTensor, "_orig_matmul"):

def _sparse_tensor_matmul(self, dense): # noqa: D401
return sparse_mm(self, dense)

torch_sparse.SparseTensor._orig_matmul = torch_sparse.SparseTensor.matmul # type: ignore[attr-defined] # noqa: E501
        torch_sparse.SparseTensor.matmul = _sparse_tensor_matmul  # type: ignore[attr-defined]
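
To illustrate the effect of the patch (a sketch assuming ``enable_seamless_sparse()`` has run, e.g. through the import side effect of ``torchjd.sparse``), plain calls to ``torch.sparse.mm`` and the ``@`` operator on a sparse tensor both end up in the vmap-safe ``sparse_mm``:

>>> import torch
>>> import torchjd.sparse  # importing runs enable_seamless_sparse()
>>>
>>> A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0], (2, 2)).coalesce()
>>> X = torch.randn(2, 3, requires_grad=True)
>>> Y1 = torch.sparse.mm(A, X)  # routed through the patched wrapper
>>> Y2 = A @ X                  # Tensor.__matmul__ is patched as well
>>> torch.allclose(Y1, Y2)
True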
11 changes: 11 additions & 0 deletions src/torchjd/sparse/_registry.py
@@ -0,0 +1,11 @@
"""Central registry of sparse conversions and helpers.

For now this file simply re-exports :func:`to_coalesced_coo`, but keeps
the door open for future registration logic.
"""

from __future__ import annotations

from ._utils import to_coalesced_coo

__all__ = ["to_coalesced_coo"]
37 changes: 37 additions & 0 deletions src/torchjd/sparse/_utils.py
@@ -0,0 +1,37 @@
"""Utility helpers shared by the sparse sub-package."""

from __future__ import annotations

from typing import Any

import torch

try:
import importlib

torch_sparse = importlib.import_module("torch_sparse") # type: ignore
except (ModuleNotFoundError, OSError): # pragma: no cover
torch_sparse = None


def to_coalesced_coo(x: Any) -> torch.Tensor:
"""Convert *x* to a **coalesced** PyTorch sparse COO tensor."""

if isinstance(x, torch.Tensor) and x.is_sparse:
return x.coalesce()

if torch_sparse and isinstance(x, torch_sparse.SparseTensor): # type: ignore
return x.to_torch_sparse_coo_tensor().coalesce()

try:
import scipy.sparse as sp # pragma: no cover

if isinstance(x, sp.spmatrix):
coo = x.tocoo()
indices = torch.as_tensor([coo.row, coo.col], dtype=torch.long)
values = torch.as_tensor(coo.data, dtype=torch.get_default_dtype())
return torch.sparse_coo_tensor(indices, values, coo.shape).coalesce()
except ModuleNotFoundError: # pragma: no cover
pass

raise TypeError(f"Unsupported sparse type: {type(x)}") # pragma: no cover
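
A short sketch of the conversion path for SciPy inputs (only valid when ``scipy`` is installed; ``torchjd.sparse._utils`` is a private module, so the import is for illustration only):

>>> import scipy.sparse as sp
>>> from torchjd.sparse._utils import to_coalesced_coo
>>>
>>> m = sp.coo_matrix(([1.0, 1.0], ([0, 1], [1, 0])), shape=(2, 2))
>>> A = to_coalesced_coo(m)   # SciPy COO -> coalesced torch sparse COO tensor
>>> A.is_sparse, A.is_coalesced()
(True, True)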
2 changes: 1 addition & 1 deletion tests/doc/test_backward.py
@@ -9,7 +9,7 @@
def test_backward():
import torch

from torchjd import backward
from torchjd._autojac import backward
from torchjd.aggregation import UPGrad

param = torch.tensor([1.0, 2.0], requires_grad=True)
14 changes: 7 additions & 7 deletions tests/doc/test_rst.py
@@ -27,7 +27,7 @@ def test_basic_usage():
loss2 = loss_fn(output[:, 1], target2)

optimizer.zero_grad()
torchjd.backward([loss1, loss2], aggregator)
torchjd._autojac.backward([loss1, loss2], aggregator)
optimizer.step()


@@ -58,7 +58,7 @@ def test_iwrm_with_ssjd():
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd import backward
from torchjd._autojac import backward
from torchjd.aggregation import UPGrad

X = torch.randn(8, 16, 10)
@@ -87,7 +87,7 @@ def test_mtl():
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd import mtl_backward
from torchjd._autojac import mtl_backward
from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
@@ -136,7 +136,7 @@ def test_lightning_integration():
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

from torchjd import mtl_backward
from torchjd._autojac import mtl_backward
from torchjd.aggregation import UPGrad

class Model(LightningModule):
@@ -190,7 +190,7 @@ def test_rnn():
from torch.nn import RNN
from torch.optim import SGD

from torchjd import backward
from torchjd._autojac import backward
from torchjd.aggregation import UPGrad

rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
@@ -215,7 +215,7 @@ def test_monitoring():
from torch.nn.functional import cosine_similarity
from torch.optim import SGD

from torchjd import mtl_backward
from torchjd._autojac import mtl_backward
from torchjd.aggregation import UPGrad

def print_weights(_, __, weights: torch.Tensor) -> None:
@@ -267,7 +267,7 @@ def test_amp():
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd import mtl_backward
from torchjd._autojac import mtl_backward
from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
2 changes: 1 addition & 1 deletion tests/unit/autojac/test_backward.py
@@ -3,7 +3,7 @@
from torch.autograd import grad
from torch.testing import assert_close

from torchjd import backward
from torchjd._autojac import backward
from torchjd._autojac._backward import _create_transform
from torchjd._autojac._transform import OrderedSet
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad
2 changes: 1 addition & 1 deletion tests/unit/autojac/test_mtl_backward.py
@@ -3,7 +3,7 @@
from torch.autograd import grad
from torch.testing import assert_close

from torchjd import mtl_backward
from torchjd._autojac import mtl_backward
from torchjd._autojac._mtl_backward import _create_transform
from torchjd._autojac._transform import OrderedSet
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad