+import weakref
+from collections.abc import Callable
 from typing import cast
 
 import torch
 # still support older versions of PyTorch where pytree is protected).
 
 
+class BoolRef:
+    """Class wrapping a boolean value, acting as a reference to this boolean value."""
+
+    def __init__(self, value: bool):
+        self.value = value
+
+    def __bool__(self) -> bool:
+        return self.value
+
+
 class ModuleHookManager:
     """
     Class responsible for handling hooks and Nodes that compute the Gramian reverse accumulation.
@@ -35,9 +47,19 @@ def __init__(
     ):
         self._target_edges = target_edges
         self._gramian_accumulator = gramian_accumulator
-        self.gramian_accumulation_phase = False
+        self.gramian_accumulation_phase = BoolRef(False)
         self._handles: list[TorchRemovableHandle] = []
 
+        # When the ModuleHookManager is not referenced anymore, there is no reason to keep the
+        # hooks alive. In fact, keeping the hooks alive would also keep the target edges alive,
+        # which would keep the graph, or part of it, alive. Since the graph contains nodes that
+        # store the module in their context, and modules themselves reference their hooks, the
+        # hooks would be caught in a reference cycle and would not be freed by the garbage
+        # collector. It is thus important to remove the hooks as soon as we are sure we will not
+        # need them anymore. We could have used a __del__ method here, with the same effect, but
+        # weakref.finalize seems to be better practice (and it only works if the function to call
+        # is static).
+        self._finalizer = weakref.finalize(self, ModuleHookManager.remove_hooks, self._handles)
+
     def hook_module(self, module: nn.Module) -> None:
         """
         Add a module hook used to insert Jacobian accumulation nodes into the backward graph.
@@ -46,85 +68,133 @@ def hook_module(self, module: nn.Module) -> None:
         enabling Gramian computation.
         """
 
-        def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree:
-            if self.gramian_accumulation_phase:
-                return output
-
-            flat_outputs, tree_spec = tree_flatten(output)
+        hook = Hook(self.gramian_accumulation_phase, self._target_edges, self._gramian_accumulator)
+        self._handles.append(module.register_forward_hook(hook))
 
-            if not any(isinstance(t, Tensor) for t in flat_outputs):
-                # This can happen only if a module returns no Tensor, for instance some niche usage
-                # such as a module that prints something.
-                return output
-
-            requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
-            self._gramian_accumulator.track_parameter_paths(requires_grad_params)
+    @staticmethod
+    def remove_hooks(handles: list[TorchRemovableHandle]) -> None:
+        """
+        Remove all registered hooks. This method is deliberately static so that it can be called by
+        weakref.finalize.
+        """
 
-            # We only care about running the JacobianAccumulator node, so we need one of its child
-            # edges (the edges of the original outputs of the model) as target. For memory
-            # efficiency, we select the smallest one (that requires grad).
-            inf = float("inf")
-            preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
-            index = cast(int, preference.argmin().item())
-            self._target_edges.register(get_gradient_edge(flat_outputs[index]))
+        for handle in handles:
+            handle.remove()
 
-            return self._apply_jacobian_accumulator(module, args, tree_spec, flat_outputs)
 
-        handle = module.register_forward_hook(module_hook)
-        self._handles.append(handle)
+class AccumulateJacobian(torch.autograd.Function):
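+    """
+    Autograd function whose forward pass computes the vector-Jacobian products of a module with
+    respect to its parameters (via a vmapped VJP) and feeds the resulting Jacobians to the
+    Gramian accumulator. It is applied from JacobianAccumulator.backward, so it only runs during
+    the backward pass of the Gramian accumulation phase.
+    """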
 
-    def _apply_jacobian_accumulator(
-        self,
-        module: nn.Module,
-        args: PyTree,
+    @staticmethod
+    def forward(
+        ctx,
         tree_spec: TreeSpec,
-        flat_outputs: list[Tensor],
-    ) -> PyTree:
-        vjp = torch.vmap(get_functional_vjp(module))
-
-        class AccumulateJacobian(torch.autograd.Function):
-
-            @staticmethod
-            def forward(*flat_grad_outputs: Tensor) -> None:
-                grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec)
-                jacobians = vjp(grad_outputs, args)
-                self._gramian_accumulator.accumulate_path_jacobians(
-                    {
-                        module.get_parameter(param_name): jacobian
-                        for param_name, jacobian in jacobians.items()
-                    }
-                )
+        vjp: Callable[[PyTree, PyTree], dict[str, Tensor]],
+        args: PyTree,
+        gramian_accumulator: GramianAccumulator,
+        module: nn.Module,
+        *flat_grad_outputs: Tensor,
+    ) -> None:
+        grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec)
+        jacobians = vjp(grad_outputs, args)
+        gramian_accumulator.accumulate_path_jacobians(
+            {
+                module.get_parameter(param_name): jacobian
+                for param_name, jacobian in jacobians.items()
+            }
+        )
+
+
+class JacobianAccumulator(torch.autograd.Function):
+    """
+    Autograd function that accumulates Jacobian Gramians during the first backward pass.
 
-            @staticmethod
-            def setup_context(*_):
-                pass
+    Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian
+    of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a
+    toggle mechanism to activate only during the Gramian accumulation phase.
+    """
 
-        class JacobianAccumulator(torch.autograd.Function):
-            """
-            Autograd function that accumulates Jacobian Gramians during the first backward pass.
+    generate_vmap_rule = True
 
-            Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian
-            of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a
-            toggle mechanism to activate only during the Gramian accumulation phase.
-            """
+    @staticmethod
+    def forward(
+        ctx,
+        gramian_accumulation_phase: BoolRef,
+        tree_spec: TreeSpec,
+        vjp: Callable[[PyTree, PyTree], dict[str, Tensor]],
+        args: PyTree,
+        gramian_accumulator: GramianAccumulator,
+        module: nn.Module,
+        *xs: Tensor,
+    ) -> tuple[Tensor, ...]:
+        ctx.gramian_accumulation_phase = gramian_accumulation_phase
+        ctx.tree_spec = tree_spec
+        ctx.vjp = vjp
+        ctx.args = args
+        ctx.gramian_accumulator = gramian_accumulator
+        ctx.module = module
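+        # Identity on the values: detach() returns fresh output tensors, and the gradient flow
+        # through them is handled by this Function's backward.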
+        return tuple([x.detach() for x in xs])
+
+    @staticmethod
+    def backward(ctx, *flat_grad_outputs: Tensor):
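+        # backward must return one value per input of forward: None for the six non-tensor
+        # arguments (gramian_accumulation_phase, tree_spec, vjp, args, gramian_accumulator,
+        # module) and the incoming gradients, passed through unchanged, for *xs.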
+        if not ctx.gramian_accumulation_phase:
+            return None, None, None, None, None, None, *flat_grad_outputs
+
+        AccumulateJacobian.apply(
+            ctx.tree_spec,
+            ctx.vjp,
+            ctx.args,
+            ctx.gramian_accumulator,
+            ctx.module,
+            *flat_grad_outputs,
+        )
+
+        return None, None, None, None, None, None, *flat_grad_outputs
+
+
+class Hook:
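+    """
+    Forward hook that redirects the outputs of a module through a JacobianAccumulator node and
+    registers one of the output edges as a target for the Gramian accumulation backward pass.
+    """
+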
+    def __init__(
+        self,
+        gramian_accumulation_phase: BoolRef,
+        target_edges: EdgeRegistry,
+        gramian_accumulator: GramianAccumulator,
+    ):
+        self.gramian_accumulation_phase = gramian_accumulation_phase
+        self.target_edges = target_edges
+        self.gramian_accumulator = gramian_accumulator
 
-            generate_vmap_rule = True
+    def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
+        if self.gramian_accumulation_phase:
+            return output
 
-            @staticmethod
-            def forward(*xs: Tensor) -> tuple[Tensor, ...]:
-                return tuple([x.detach() for x in xs])
+        flat_outputs, tree_spec = tree_flatten(output)
 
-            @staticmethod
-            def setup_context(*_):
-                pass
+        if not any(isinstance(t, Tensor) for t in flat_outputs):
+            # This can happen only if a module returns no Tensor, for instance some niche usage
+            # such as a module that prints something.
+            return output
 
-            @staticmethod
-            def backward(ctx, *flat_grad_outputs: Tensor):
-                if not self.gramian_accumulation_phase:
-                    return flat_grad_outputs
+        requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
+        self.gramian_accumulator.track_parameter_paths(requires_grad_params)
 
-                AccumulateJacobian.apply(*flat_grad_outputs)
+        # We only care about running the JacobianAccumulator node, so we need one of its child
+        # edges (the edges of the original outputs of the model) as target. For memory
+        # efficiency, we select the smallest one (that requires grad).
+        inf = float("inf")
+        preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
+        index = cast(int, preference.argmin().item())
+        self.target_edges.register(get_gradient_edge(flat_outputs[index]))
 
-                return flat_grad_outputs
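+        # Batched (vmapped) VJP of the module with respect to its parameters; AccumulateJacobian
+        # uses it to turn batched gradient outputs into per-parameter Jacobians.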
+        vjp = torch.vmap(get_functional_vjp(module))
 
-        return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), tree_spec)
+        return tree_unflatten(
+            JacobianAccumulator.apply(
+                self.gramian_accumulation_phase,
+                tree_spec,
+                vjp,
+                args,
+                self.gramian_accumulator,
+                module,
+                *flat_outputs,
+            ),
+            tree_spec,
+        )