⚡ Performance update ⚡
In this release, we have removed some unnecessary overhead from TorchJD's internal code. This should lead to small but noticeable performance improvements (up to a 10% speedup).
We have also made TorchJD more lightweight by making the dependencies that were only used by `CAGrad` and `NashMTL` optional (the changelog explains how to keep installing them).
Finally, we have fixed all internal type errors with the help of mypy, and we have added a `py.typed` marker file so that mypy can be used on downstream code.
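For downstream projects, the effect of the `py.typed` marker can be illustrated with a minimal, hypothetical sketch (the helper function below is made up for illustration; `UPGrad` and its call convention on a Jacobian matrix follow TorchJD's documented usage):

```python
# downstream_module.py — check with: mypy downstream_module.py
from torch import Tensor
from torchjd.aggregation import UPGrad

def combine_rows(jacobian: Tensor) -> Tensor:
    # Hypothetical helper: with the py.typed marker shipped in this release,
    # mypy resolves torchjd's annotations instead of treating the package as untyped.
    aggregator = UPGrad()
    return aggregator(jacobian)
```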
## Changelog
### Changed
- **BREAKING**: Changed the dependencies of `CAGrad` and `NashMTL` to be optional when installing TorchJD. Users of these aggregators will have to use `pip install torchjd[cagrad]`, `pip install torchjd[nash_mtl]` or `pip install torchjd[full]` to install TorchJD alongside those dependencies. This should make TorchJD more lightweight.
- **BREAKING**: Made the aggregator modules and the `autojac` package protected. The aggregators must now always be imported via their package (e.g. `from torchjd.aggregation.upgrad import UPGrad` must be changed to `from torchjd.aggregation import UPGrad`). The `backward` and `mtl_backward` functions must now always be imported directly from the `torchjd` package (e.g. `from torchjd.autojac.mtl_backward import mtl_backward` must be changed to `from torchjd import mtl_backward`). A migration sketch is given after this list.
- Removed the check that the input Jacobian matrix provided to an aggregator does not contain `nan`, `inf` or `-inf` values. This check was costly in memory and in time for large matrices, so removing it should improve performance. However, if the optimization diverges for some reason (for instance due to a too large learning rate), the resulting exceptions may now come from other sources. A sketch of how to reinstate such a check manually is given after this list.
- Removed some runtime checks on the shapes of the internal tensors used by the `autojac` engine. This should lead to a small performance improvement.
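For the import change above, a minimal migration sketch (taken directly from the examples in that item):

```python
# Old imports (no longer available):
# from torchjd.aggregation.upgrad import UPGrad
# from torchjd.autojac.mtl_backward import mtl_backward

# New imports:
from torchjd.aggregation import UPGrad
from torchjd import backward, mtl_backward
```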
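For the removed `nan`/`inf` check, users who call an aggregator directly on a Jacobian matrix and want to keep the old safety net can add a manual guard. A minimal sketch, assuming the usual call convention of an aggregator instance on a 2D tensor (the example matrix is made up):

```python
import torch
from torchjd.aggregation import UPGrad

aggregator = UPGrad()
jacobian = torch.randn(4, 10)  # hypothetical Jacobian: one row per objective

# TorchJD no longer validates this itself; a manual check restores the old behavior,
# which can help when a run diverges (e.g. because of a too large learning rate).
if not torch.isfinite(jacobian).all():
    raise ValueError("Jacobian contains nan, inf or -inf values")

aggregated = aggregator(jacobian)
```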
### Fixed
- Made some aggregators (`CAGrad`, `ConFIG`, `DualProj`, `GradDrop`, `IMTLG`, `NashMTL`, `PCGrad` and `UPGrad`) raise a `NonDifferentiableError` whenever one tries to differentiate through them. Before this change, trying to differentiate through them led to wrong gradients or unclear errors. A sketch illustrating the new behavior is given after this list.
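To illustrate, a minimal sketch of the new behavior, assuming the usual call convention of an aggregator instance on a 2D tensor (the exact import path of `NonDifferentiableError` is not shown here, so a broad `except` is used):

```python
import torch
from torchjd.aggregation import UPGrad

aggregator = UPGrad()
jacobian = torch.randn(3, 5, requires_grad=True)

try:
    aggregated = aggregator(jacobian)
    # Differentiating through the aggregation is not supported; it now raises
    # a NonDifferentiableError instead of producing wrong gradients or unclear errors.
    aggregated.sum().backward()
except Exception as err:  # expected: NonDifferentiableError
    print(f"{type(err).__name__}: {err}")
```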