20 | 20 | # still support older versions of PyTorch where pytree is protected). |
21 | 21 |
22 | 22 |
23 | | -class BoolRef: |
24 | | - """Class wrapping a boolean value, acting as a reference to this boolean value.""" |
25 | | - |
26 | | - def __init__(self, value: bool): |
27 | | - self.value = value |
28 | | - |
29 | | - def __bool__(self) -> bool: |
30 | | - return self.value |
31 | | - |
32 | | - |
33 | 23 | class ModuleHookManager: |
34 | 24 | """ |
35 | 25 | Class responsible for handling hooks and Nodes that compute the Gramian reverse accumulation. |
@@ -88,58 +78,64 @@ def remove_hooks(handles: list[TorchRemovableHandle]) -> None: |
88 | 78 | handle.remove() |
89 | 79 |
90 | 80 |
91 | | -class AccumulateJacobian(torch.autograd.Function): |
| 81 | +class BoolRef: |
| 82 | + """Class wrapping a boolean value, acting as a reference to this boolean value.""" |
92 | 83 |
93 | | - @staticmethod |
94 | | - def forward( |
95 | | - output_spec: TreeSpec, |
96 | | - vjp: VJP, |
97 | | - args: PyTree, |
98 | | - gramian_accumulator: GramianAccumulator, |
99 | | - module: nn.Module, |
100 | | - *flat_grad_outputs: Tensor, |
101 | | - ) -> None: |
102 | | - # There is no non-batched dimension |
103 | | - grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) |
104 | | - generalized_jacobians = vjp(grad_outputs, args) |
105 | | - path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) |
106 | | - gramian_accumulator.accumulate_path_jacobians(path_jacobians) |
| 84 | + def __init__(self, value: bool): |
| 85 | + self.value = value |
107 | 86 |
108 | | - @staticmethod |
109 | | - def vmap( |
110 | | - _, |
111 | | - in_dims: PyTree, |
112 | | - output_spec: TreeSpec, |
113 | | - vjp: VJP, |
114 | | - args: PyTree, |
| 87 | + def __bool__(self) -> bool: |
| 88 | + return self.value |
| 89 | + |
| 90 | + |
| 91 | +class Hook: |
| 92 | + def __init__( |
| 93 | + self, |
| 94 | + gramian_accumulation_phase: BoolRef, |
| 95 | + target_edges: EdgeRegistry, |
115 | 96 | gramian_accumulator: GramianAccumulator, |
116 | | - module: nn.Module, |
117 | | - *flat_jac_outputs: Tensor, |
118 | | - ) -> tuple[None, None]: |
119 | | - # There is a non-batched dimension |
120 | | - jac_outputs = tree_unflatten(flat_jac_outputs, output_spec) |
121 | | - # We do not vmap over the args for the non-batched dimension |
122 | | - in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args)) |
123 | | - generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) |
124 | | - path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) |
125 | | - gramian_accumulator.accumulate_path_jacobians(path_jacobians) |
126 | | - return None, None |
| 97 | + has_batch_dim: bool, |
| 98 | + ): |
| 99 | + self.gramian_accumulation_phase = gramian_accumulation_phase |
| 100 | + self.target_edges = target_edges |
| 101 | + self.gramian_accumulator = gramian_accumulator |
| 102 | + self.has_batch_dim = has_batch_dim |
127 | 103 |
128 | | - @staticmethod |
129 | | - def _make_path_jacobians( |
130 | | - module: nn.Module, |
131 | | - generalized_jacobians: dict[str, Tensor], |
132 | | - ) -> dict[Tensor, Tensor]: |
133 | | - path_jacobians: dict[Tensor, Tensor] = {} |
134 | | - for param_name, generalized_jacobian in generalized_jacobians.items(): |
135 | | - key = module.get_parameter(param_name) |
136 | | - jacobian = generalized_jacobian.reshape([-1] + list(key.shape)) |
137 | | - path_jacobians[key] = jacobian |
138 | | - return path_jacobians |
| 104 | + def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: |
| 105 | + if self.gramian_accumulation_phase: |
| 106 | + return output |
139 | 107 |
140 | | - @staticmethod |
141 | | - def setup_context(*_): |
142 | | - pass |
| 108 | + flat_outputs, output_spec = tree_flatten(output) |
| 109 | + |
| 110 | + if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs): |
| 111 | + # This can happen only if a module has a trainable param but outputs no tensor that |
| 108 | +         # requires grad |
| 113 | + return output |
| 114 | + |
| 115 | + requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] |
| 116 | + self.gramian_accumulator.track_parameter_paths(requires_grad_params) |
| 117 | + |
| 118 | + # We only care about running the JacobianAccumulator node, so we need one of its child |
| 119 | + # edges (the edges of the original outputs of the model) as target. For memory |
| 120 | + # efficiency, we select the smallest one (that requires grad). |
| 121 | + inf = float("inf") |
| 122 | + preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs]) |
| 123 | + index = cast(int, preference.argmin().item()) |
| 124 | + self.target_edges.register(get_gradient_edge(flat_outputs[index])) |
| 125 | + |
| 126 | + vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs) |
| 127 | + |
| 128 | + autograd_fn_outputs = JacobianAccumulator.apply( |
| 129 | + self.gramian_accumulation_phase, |
| 130 | + output_spec, |
| 131 | + vjp, |
| 132 | + args, |
| 133 | + self.gramian_accumulator, |
| 134 | + module, |
| 135 | + *flat_outputs, |
| 136 | + ) |
| 137 | + |
| 138 | + return tree_unflatten(autograd_fn_outputs, output_spec) |
143 | 139 |
144 | 140 |
145 | 141 | class JacobianAccumulator(torch.autograd.Function): |
@@ -197,51 +193,55 @@ def backward(ctx, *flat_grad_outputs: Tensor): |
197 | 193 | return None, None, None, None, None, None, *flat_grad_outputs |
198 | 194 |
199 | 195 |
200 | | -class Hook: |
201 | | - def __init__( |
202 | | - self, |
203 | | - gramian_accumulation_phase: BoolRef, |
204 | | - target_edges: EdgeRegistry, |
205 | | - gramian_accumulator: GramianAccumulator, |
206 | | - has_batch_dim: bool, |
207 | | - ): |
208 | | - self.gramian_accumulation_phase = gramian_accumulation_phase |
209 | | - self.target_edges = target_edges |
210 | | - self.gramian_accumulator = gramian_accumulator |
211 | | - self.has_batch_dim = has_batch_dim |
212 | | - |
213 | | - def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: |
214 | | - if self.gramian_accumulation_phase: |
215 | | - return output |
216 | | - |
217 | | - flat_outputs, output_spec = tree_flatten(output) |
218 | | - |
219 | | - if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs): |
220 | | - # This can happen only if a module has a trainable param but outputs no tensor that |
221 | | - # require grad |
222 | | - return output |
223 | | - |
224 | | - requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] |
225 | | - self.gramian_accumulator.track_parameter_paths(requires_grad_params) |
| 196 | +class AccumulateJacobian(torch.autograd.Function): |
226 | 197 |
227 | | - # We only care about running the JacobianAccumulator node, so we need one of its child |
228 | | - # edges (the edges of the original outputs of the model) as target. For memory |
229 | | - # efficiency, we select the smallest one (that requires grad). |
230 | | - inf = float("inf") |
231 | | - preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs]) |
232 | | - index = cast(int, preference.argmin().item()) |
233 | | - self.target_edges.register(get_gradient_edge(flat_outputs[index])) |
| 198 | + @staticmethod |
| 199 | + def forward( |
| 200 | + output_spec: TreeSpec, |
| 201 | + vjp: VJP, |
| 202 | + args: PyTree, |
| 203 | + gramian_accumulator: GramianAccumulator, |
| 204 | + module: nn.Module, |
| 205 | + *flat_grad_outputs: Tensor, |
| 206 | + ) -> None: |
| 207 | + # There is no non-batched dimension |
| 208 | + grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) |
| 209 | + generalized_jacobians = vjp(grad_outputs, args) |
| 210 | + path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) |
| 211 | + gramian_accumulator.accumulate_path_jacobians(path_jacobians) |
234 | 212 |
235 | | - vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs) |
| 213 | + @staticmethod |
| 214 | + def vmap( |
| 215 | + _, |
| 216 | + in_dims: PyTree, |
| 217 | + output_spec: TreeSpec, |
| 218 | + vjp: VJP, |
| 219 | + args: PyTree, |
| 220 | + gramian_accumulator: GramianAccumulator, |
| 221 | + module: nn.Module, |
| 222 | + *flat_jac_outputs: Tensor, |
| 223 | + ) -> tuple[None, None]: |
| 224 | + # There is a non-batched dimension |
| 225 | + jac_outputs = tree_unflatten(flat_jac_outputs, output_spec) |
| 226 | + # We do not vmap over the args for the non-batched dimension |
| 227 | + in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args)) |
| 228 | + generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) |
| 229 | + path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) |
| 230 | + gramian_accumulator.accumulate_path_jacobians(path_jacobians) |
| 231 | + return None, None |
236 | 232 |
237 | | - autograd_fn_outputs = JacobianAccumulator.apply( |
238 | | - self.gramian_accumulation_phase, |
239 | | - output_spec, |
240 | | - vjp, |
241 | | - args, |
242 | | - self.gramian_accumulator, |
243 | | - module, |
244 | | - *flat_outputs, |
245 | | - ) |
| 233 | + @staticmethod |
| 234 | + def _make_path_jacobians( |
| 235 | + module: nn.Module, |
| 236 | + generalized_jacobians: dict[str, Tensor], |
| 237 | + ) -> dict[Tensor, Tensor]: |
| 238 | + path_jacobians: dict[Tensor, Tensor] = {} |
| 239 | + for param_name, generalized_jacobian in generalized_jacobians.items(): |
| 240 | + key = module.get_parameter(param_name) |
| 241 | + jacobian = generalized_jacobian.reshape([-1] + list(key.shape)) |
| 242 | + path_jacobians[key] = jacobian |
| 243 | + return path_jacobians |
246 | 244 |
247 | | - return tree_unflatten(autograd_fn_outputs, output_spec) |
| 245 | + @staticmethod |
| 246 | + def setup_context(*_): |
| 247 | + pass |
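
Below is a minimal, self-contained sketch (not part of this commit) of the pattern the relocated `BoolRef` and `Hook` classes rely on: every hook holds a reference to one shared `BoolRef`, so flipping its `value` switches all already-registered hooks into the Gramian accumulation phase without re-registering them. The toy `hook` function and the `nn.Linear` module are illustrative stand-ins; only `BoolRef` is copied from the diff.

```python
import torch
from torch import nn


class BoolRef:
    """Class wrapping a boolean value, acting as a reference to this boolean value."""

    def __init__(self, value: bool):
        self.value = value

    def __bool__(self) -> bool:
        return self.value


# Shared flag: every hook reads this same BoolRef instance.
gramian_accumulation_phase = BoolRef(False)


def hook(module: nn.Module, args: tuple, output: torch.Tensor) -> torch.Tensor:
    if gramian_accumulation_phase:
        # Accumulation phase: pass the output through untouched.
        return output
    # Normal phase: the real Hook would wrap the output in the
    # JacobianAccumulator autograd function here (omitted in this sketch).
    return output


model = nn.Linear(3, 2)
handle = model.register_forward_hook(hook)

_ = model(torch.randn(5, 3))             # hooks run with the flag set to False
gramian_accumulation_phase.value = True  # every registered hook now sees True
_ = model(torch.randn(5, 3))             # hooks run in the accumulation phase
handle.remove()
```

The indirection matters because `nn.Module.register_forward_hook` captures the hook object once; a plain `bool` passed at construction time could never be updated afterwards, whereas mutating the shared `BoolRef` changes the behavior of all hooks on the next forward pass.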