v0.4.0

@ValerianRey ValerianRey released this 02 Jan 20:36
· 287 commits to main since this release
1a8454e

Sequential differentiation improvements

This version improves how backward and mtl_backward differentiate when parallel_chunk_size prevents all tensors from being differentiated in parallel at once (for instance, parallel_chunk_size=2 with 3 losses).

In particular, when a single tensor has to be differentiated (e.g. when using parallel_chunk_size=1), we now avoid relying on torch.vmap, which has several known limitations.

The parameter retain_graph of backward and mtl_backward has also been changed so that it only applies to the last differentiation. In most cases, you can now simply use the default retain_graph=False (prior to this change, you had to use retain_graph=True whenever the differentiations were not all made in parallel at once). This should reduce memory overhead.
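The new retain_graph behavior can be illustrated with plain torch.autograd (this is a standalone sketch of the scheduling idea, not torchjd's internals): when losses are differentiated one at a time, the graph only needs to be retained for every differentiation except the last.

```python
import torch

# Three losses sharing one computation graph, differentiated one at a
# time (as with parallel_chunk_size=1). The graph is retained for every
# differentiation except the last, so memory is freed as soon as possible.
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * x  # shared intermediate node of the graph
losses = [y.sum(), (2 * y).sum(), (3 * y).sum()]

grads = []
for i, loss in enumerate(losses):
    last = i == len(losses) - 1
    (g,) = torch.autograd.grad(loss, x, retain_graph=not last)
    grads.append(g)
```

With this scheduling, the caller never has to pass retain_graph=True manually, no matter how the differentiations are chunked.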

Lastly, this update enables the usage of torchjd for training recurrent neural networks. As @lth456321 discovered, there can be an incompatibility between torch.vmap and torch.nn.RNN when running on CUDA. With this update, you can now simply set parallel_chunk_size to 1 to avoid using torch.vmap and fix the problem. A usage example for RNNs has therefore been added to the documentation.

Changelog

Changed

  • Changed how the Jacobians are computed when calling backward or mtl_backward with
    parallel_chunk_size=1: torch.vmap is no longer used in this case. Whenever vmap does not
    support something (compiled functions, RNN on CUDA, etc.), users should now be able to avoid
    using vmap by calling backward or mtl_backward with parallel_chunk_size=1.

  • Changed the effect of the parameter retain_graph of backward and mtl_backward. When set to
    False, it now frees the graph only after all gradients have been computed. In most cases, users
    should now leave the default value retain_graph=False, no matter what the value of
    parallel_chunk_size is. This will reduce the memory overhead.

Added

  • RNN training usage example in the documentation.

Contributors