11import math
22from collections .abc import Callable , Sequence
33from functools import partial
4- from itertools import accumulate
54
65import torch
7- from torch import Size , Tensor
6+ from torch import Tensor
87
98from ._differentiate import Differentiate
10- from ._materialize import materialize
119from ._ordered_set import OrderedSet
1210
1311
@@ -69,53 +67,37 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
6967 ]
7068 )
7169
72- def _get_vjp (grad_outputs : Sequence [Tensor ], retain_graph : bool ) -> Tensor :
73- optional_grads = torch .autograd .grad (
74- self .outputs ,
75- self .inputs ,
76- grad_outputs = grad_outputs ,
77- retain_graph = retain_graph ,
78- create_graph = self .create_graph ,
79- allow_unused = True ,
80- )
81- grads = materialize (optional_grads , inputs = self .inputs )
82- return torch .concatenate ([grad .reshape ([- 1 ]) for grad in grads ])
83-
8470 # If the jac_outputs are correct, this value should be the same for all jac_outputs.
8571 m = jac_outputs [0 ].shape [0 ]
8672 max_chunk_size = self .chunk_size if self .chunk_size is not None else m
8773 n_chunks = math .ceil (m / max_chunk_size )
8874
89- # List of tensors of shape [k_i, n] where the k_i's sum to m
90- jac_matrix_chunks = []
75+ # One tuple per chunk (i), with one value per input (j), of shape [k_i] + shape[j],
76+ # where k_i is the number of rows in the chunk (the k_i's sum to m)
77+ jacs_chunks : list [tuple [Tensor , ...]] = []
9178
9279 # First differentiations: always retain graph
93- get_vjp_retain = partial (_get_vjp , retain_graph = True )
80+ get_vjp_retain = partial (self . _get_vjp , retain_graph = True )
9481 for i in range (n_chunks - 1 ):
9582 start = i * max_chunk_size
9683 end = (i + 1 ) * max_chunk_size
9784 jac_outputs_chunk = [jac_output [start :end ] for jac_output in jac_outputs ]
98- jac_matrix_chunks .append (_get_jac_matrix_chunk (jac_outputs_chunk , get_vjp_retain ))
85+ jacs_chunks .append (_get_jacs_chunk (jac_outputs_chunk , get_vjp_retain ))
9986
10087 # Last differentiation: retain the graph only if self.retain_graph==True
101- get_vjp_last = partial (_get_vjp , retain_graph = self .retain_graph )
88+ get_vjp_last = partial (self . _get_vjp , retain_graph = self .retain_graph )
10289 start = (n_chunks - 1 ) * max_chunk_size
10390 jac_outputs_chunk = [jac_output [start :] for jac_output in jac_outputs ]
104- jac_matrix_chunks .append (_get_jac_matrix_chunk (jac_outputs_chunk , get_vjp_last ))
105-
106- jac_matrix = torch .vstack (jac_matrix_chunks )
107- lengths = [input .numel () for input in self .inputs ]
108- jac_matrices = _extract_sub_matrices (jac_matrix , lengths )
109-
110- shapes = [input .shape for input in self .inputs ]
111- jacs = _reshape_matrices (jac_matrices , shapes )
91+ jacs_chunks .append (_get_jacs_chunk (jac_outputs_chunk , get_vjp_last ))
11292
113- return tuple (jacs )
93+ n_inputs = len (self .inputs )
94+ jacs = tuple (torch .cat ([chunks [i ] for chunks in jacs_chunks ]) for i in range (n_inputs ))
95+ return jacs
11496
11597
116- def _get_jac_matrix_chunk (
117- jac_outputs_chunk : list [Tensor ], get_vjp : Callable [[Sequence [Tensor ]], Tensor ]
118- ) -> Tensor :
98+ def _get_jacs_chunk (
99+ jac_outputs_chunk : list [Tensor ], get_vjp : Callable [[Sequence [Tensor ]], tuple [ Tensor , ...] ]
100+ ) -> tuple [ Tensor , ...] :
119101 """
120102 Computes the jacobian matrix chunk corresponding to the provided get_vjp function, either by
121103 calling get_vjp directly or by wrapping it into a call to ``torch.vmap``, depending on the shape
@@ -126,18 +108,7 @@ def _get_jac_matrix_chunk(
126108 chunk_size = jac_outputs_chunk [0 ].shape [0 ]
127109 if chunk_size == 1 :
128110 grad_outputs = [tensor .squeeze (0 ) for tensor in jac_outputs_chunk ]
129- gradient_vector = get_vjp (grad_outputs )
130- return gradient_vector .unsqueeze (0 )
111+ gradients = get_vjp (grad_outputs )
112+ return tuple ( gradient .unsqueeze (0 ) for gradient in gradients )
131113 else :
132114 return torch .vmap (get_vjp , chunk_size = chunk_size )(jac_outputs_chunk )
133-
134-
135- def _extract_sub_matrices (matrix : Tensor , lengths : Sequence [int ]) -> list [Tensor ]:
136- cumulative_lengths = [* accumulate (lengths )]
137- start_indices = [0 ] + cumulative_lengths [:- 1 ]
138- end_indices = cumulative_lengths
139- return [matrix [:, start :end ] for start , end in zip (start_indices , end_indices )]
140-
141-
142- def _reshape_matrices (matrices : Sequence [Tensor ], shapes : Sequence [Size ]) -> Sequence [Tensor ]:
143- return [matrix .view ((matrix .shape [0 ],) + shape ) for matrix , shape in zip (matrices , shapes )]
0 commit comments