You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
refactor(autogram): Make engine hook recursively (#451)
* Make the Engine hook modules recursively. If a direct parameter exists, hook the module and do not hook its children. If not, try to hook the child modules.
* This change means that only the parentmost module with direct rg params gets hooked. Its used parameters are thus simply module.parameters(recurse=True) now. This is even the case in special cases where the parent uses the child parameters, so we don't need to have a special case for MHA anymore.
* Remove _module_utils.py: it's now trivial to know with respect to which parameters to differentiate.
* Update all usages to now create the Engine with Engine(model) instead of Engine(model.modules()). For partial JD, users have to be more careful, as they should sometimes specify several modules, but these modules should be "disjoint" (i.e. no specified module should be a child of another specified module)
* This mostly makes a difference on FreeParam. Before, we had 2 hooks (one for the parent, parameterized with the parent's param - aka the free param, and one for the child module, parameterized with the child's params). Now we simply have 1 hook for the parent, parameterized with the all parameters (i.e. parent.parameters(recurse=True)). This is probably faster (because we don't have to do 2 extra forwards and 2 extra backwards for the child, but just 1 now), but maybe a bit more memory consuming (because we have to store the Jacobian wrt the child's params and wrt the parent's free param at the same time). This case is quite niche though, and I still see it as an improvement.
* Change Engine to take *modules: nn.Module instead of Iterable[nn.Module] (more convenient for the new usage, because we only specify one model 99% of the time). Update the docstring accordingly.
0 commit comments