@@ -1,11 +1,9 @@
-from abc import ABC, abstractmethod
-from collections.abc import Callable
-from typing import cast
+from abc import ABC

 import torch
 from torch import Tensor, nn
 from torch.nn import Parameter
-from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
+from torch.utils._pytree import tree_flatten

 # Note about import from protected _pytree module:
 # PyTorch maintainers plan to make pytree public (see
@@ -25,112 +23,26 @@ class JacobianComputer(ABC):

     def __init__(self, module: nn.Module):
         self.module = module
-
         self.rg_params = dict[str, Parameter]()
-        self.frozen_params = dict[str, Parameter]()

         for name, param in module.named_parameters(recurse=True):
             if param.requires_grad:
                 self.rg_params[name] = param
-            else:
-                self.frozen_params[name] = param
-
-    def __call__(
-        self,
-        rg_outputs: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
-    ) -> Tensor:
-        # This makes __call__ vmappable.
-        return ComputeModuleJacobians.apply(
-            self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs
-        )

-    @abstractmethod
-    def _compute_jacobian(
-        self,
-        rg_outputs: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
-    ) -> Tensor:
+    def __call__(self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]) -> Tensor:
         """
-        Computes and returns the Jacobian. The output must be a matrix (2D Tensor).
+        Computes and returns the Jacobian. The output must be a generalized Jacobian with param
+        dimensions grouped.
         """


-class FunctionalJacobianComputer(JacobianComputer):
-    """
-    JacobianComputer using the functional differentiation API. This requires to use vmap, so it's
-    not compatible with every module, and it requires to have an extra forward pass to create the
-    vjp function.
-    """
-
-    def _compute_jacobian(
-        self,
-        _: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
-    ) -> Tensor:
-        grad_outputs_in_dims = (0,) * len(grad_outputs)
-        args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
-        kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
-        in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims)
-        vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)
-
-        return vmapped_vjp(grad_outputs, args, kwargs)
-
-    def _call_on_one_instance(
-        self,
-        grad_outputs_j: tuple[Tensor, ...],
-        args_j: tuple[PyTree, ...],
-        kwargs_j: dict[str, PyTree],
-    ) -> Tensor:
-        # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
-        # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
-        # nn.Flatten) do not work equivalently if they're provided with a batch or with
-        # an element of a batch. We thus always provide them with batches, just of a
-        # different size.
-        args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
-        kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j)
-        grad_outputs_j_ = tuple(x.unsqueeze(0) for x in grad_outputs_j)
-
-        def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]:
-            all_state = [
-                cast(dict[str, Tensor], rg_params),
-                dict(self.module.named_buffers()),
-                cast(dict[str, Tensor], self.frozen_params),
-            ]
-            output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j)
-            flat_outputs = tree_flatten(output)[0]
-            rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad)
-            return rg_outputs
-
-        vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]
-
-        # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the
-        # functional has a single primal which is dict(module.named_parameters()). We therefore take
-        # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
-        gradients = vjp_func(grad_outputs_j_)[0]
-        gradient = torch.cat([t.reshape(-1) for t in gradients.values()])
-        return gradient
-
-
 class AutogradJacobianComputer(JacobianComputer):
     """
     JacobianComputer using the autograd engine. The main advantage of using this method is that it
     doesn't require making an extra forward pass.
     """

-    def _compute_jacobian(
-        self,
-        rg_outputs: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        _: tuple[PyTree, ...],
-        __: dict[str, PyTree],
-    ) -> Tensor:
+    def __call__(self, rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]) -> Tensor:
         flat_rg_params, ___ = tree_flatten(self.rg_params)
         grads = torch.autograd.grad(
             rg_outputs,
@@ -141,47 +53,4 @@ def _compute_jacobian(
             materialize_grads=True,
         )
         flattened_grads = torch.cat([g.reshape(-1) for g in grads])
-        jacobian = flattened_grads.unsqueeze(0)
-        return jacobian
-
-
-class ComputeModuleJacobians(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        compute_jacobian_fn: Callable[
-            [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], Tensor
-        ],
-        rg_outputs: tuple[Tensor, ...],
-        grad_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
-    ) -> Tensor:
-        # There is no non-batched dimension
-        jacobian = compute_jacobian_fn(rg_outputs, grad_outputs, args, kwargs)
-        return jacobian
-
-    @staticmethod
-    def vmap(
-        _,
-        in_dims: tuple[None, None, tuple[int, ...], None, None],
-        compute_jacobian_fn: Callable,
-        rg_outputs: tuple[Tensor, ...],
-        jac_outputs: tuple[Tensor, ...],
-        args: tuple[PyTree, ...],
-        kwargs: dict[str, PyTree],
-    ) -> tuple[Tensor, None]:
-        # There is a non-batched dimension
-        # We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension
-        generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])(
-            rg_outputs,
-            jac_outputs,
-            args,
-            kwargs,
-        )
-        shape = generalized_jacobian.shape
-        jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1])
-        return jacobian, None
-
-    @staticmethod
-    def setup_context(*_) -> None:
-        pass
+        return flattened_grads
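
For reference, a minimal usage sketch of the simplified interface after this change. The toy nn.Linear module, the dummy input, and the all-ones grad_outputs below are illustrative assumptions, not part of the diff; it only assumes the post-change AutogradJacobianComputer is in scope.

import torch
from torch import nn

# Illustrative module; any nn.Module with requires_grad parameters would do.
module = nn.Linear(3, 2)
computer = AutogradJacobianComputer(module)

x = torch.randn(5, 3)
output = module(x)  # reuses the forward pass already performed by the caller; no extra pass needed

# One vector-Jacobian product w.r.t. all requires_grad parameters, returned as a single
# flattened vector (3*2 weight entries + 2 bias entries = 8), per the torch.cat above.
rg_outputs = (output,)
grad_outputs = (torch.ones_like(output),)
flat_grad = computer(rg_outputs, grad_outputs)
assert flat_grad.shape == (8,)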