feat(autogram): Add ModuleBasedGramianComputer. #458
        
base: main
Conversation
This is big for architectures with very big linear layers, like AlexNet. For other architectures, the differences are not very noticeable, though. But this is very promising. Let's fully focus on this direction, IMO.
| """ | ||
|  | ||
| G_b = torch.einsum("ik,jk->ij", dY1, dY2) | ||
| G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2) | 
This can be replaced by:

```python
import opt_einsum as oe

G_W = oe.contract("ik,il,jl,jk->ij", dY1, X, X, dY2, optimize="optimal", backend="torch")
```

but it seems to be the exact same runtime and memory usage.
Actually, whenever opt_einsum is installed, the optimized contraction is already used even without changing the line:

```python
G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)
```

We could still add the line just to make it explicit, maybe.
Whatever you prefer. I prefer not having to pass the two additional parameters; for me, what is important here is that:
- it is an einsum;
- it is fast.

But the second criterion is more of a "how" than a "what", so we don't really need to state it. For this reason I would vouch slightly for `torch.einsum`. The downside is that a user could set opt_einsum's global settings to non-optimized, thereby making it slow, but I guess that is the user's responsibility.
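For reference, this is the kind of global setting being referred to. A hedged illustration only (exposed through `torch.backends.opt_einsum` in recent PyTorch versions; not part of this PR):

```python
import torch

# When opt_einsum is installed, torch.einsum uses it to pick a good contraction path.
# A user can globally disable or weaken this, which is the caveat mentioned above.
print(torch.backends.opt_einsum.is_available())  # True when opt_einsum can be imported
torch.backends.opt_einsum.enabled = True         # set to False to skip path optimization
torch.backends.opt_einsum.strategy = "optimal"   # or "auto" / "greedy"
```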
```python
def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
    """
    X is a matrix of shape [k, n] and dY1, dY2 are matrices of shape [k, m].
```
I made it work on Transformers with that:

```python
if dY1.ndim == 1:
    G_b = torch.einsum("k,k->", dY1, dY2)
    G_W = torch.einsum("k,l,l,k->", dY1, X, X, dY2)
elif dY1.ndim == 2:
    G_b = torch.einsum("ak,ik->ai", dY1, dY2)
    G_W = torch.einsum("ak,al,il,ik->ai", dY1, X, X, dY2)
elif dY1.ndim == 3:  # Typical in transformers
    G_b = torch.einsum("abk,ijk->ai", dY1, dY2)
    G_W = torch.einsum("abk,abl,ijl,ijk->ai", dY1, X, X, dY2)
else:
    raise ValueError("Higher dimensions not supported. Open an issue if needed.")
```

Not elegant at all, but it seems to work. Maybe there's a clean way to write this that works for any number of dimensions without the ifs. Also, please review the equations; I did them basically by trial and error until the tests passed.
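One possible way to avoid the per-ndim branches, as a rough sketch only (not from the PR; `_compute_gramian_flattened` is a hypothetical helper and the result should still be validated against the tests): flatten every dimension between the first and the last one, then reuse the 3-D contractions above.

```python
import torch

def _compute_gramian_flattened(dY1, dY2, X):
    # Collapse all intermediate dimensions into one, so that shapes [b, m], [b, s, m],
    # [b, s1, s2, m], ... are all handled by the same 3-D einsums. For ndim == 2 the
    # collapsed dimension has size 1 and this reduces to "ak,ik->ai" above.
    # Handles ndim >= 2 only (the batched scenario discussed below).
    b, m, n = dY1.shape[0], dY1.shape[-1], X.shape[-1]
    dY1_3d = dY1.reshape(b, -1, m)
    dY2_3d = dY2.reshape(b, -1, m)
    X_3d = X.reshape(b, -1, n)
    G_b = torch.einsum("abk,ijk->ai", dY1_3d, dY2_3d)
    G_W = torch.einsum("abk,abl,ijl,ijk->ai", dY1_3d, X_3d, X_3d, dY2_3d)
    return G_b, G_W
```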
Well, it needs to be at least matrices (2 <= ndim) since we know it's a batched scenario. We could in principle add the non-batched scenario, but I'm not sure it would be faster than the classical Jacobian-based GramianComputer.

> I did them basically by trial and error until the tests passed.

I wish I could have done that ^^
Also pretty big for Transformers (with the change I suggested to handle higher-order tensors). Times for forward + backward on WithTransformerLarge with BS=256, A=Mean on cuda:0. Memory usage is however increased: the max batch size went from 273 (main) to 256 (this PR).
This seems to break `LinearBasedGramianComputer`. For `ModuleBasedGramianComputer` ...

Of interest: https://optimized-einsum.readthedocs.io/en/stable/reusing_paths.html
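A hedged sketch of what that page describes, applied to the contraction used here (shapes and sizes below are made-up examples, not from the PR): the contraction path can be computed once with `opt_einsum.contract_expression` and reused across calls.

```python
import opt_einsum as oe
import torch

batch, n, m = 64, 1024, 512  # example sizes only

# Build (and cache) the contraction path once for fixed operand shapes.
expr = oe.contract_expression(
    "ik,il,jl,jk->ij",
    (batch, m), (batch, n), (batch, n), (batch, m),
    optimize="optimal",
)

dY1, dY2 = torch.randn(batch, m), torch.randn(batch, m)
X = torch.randn(batch, n)
G_W = expr(dY1, X, X, dY2, backend="torch")  # reuses the precomputed path
```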
@PierreQuinton I found a way to compute the gramian with autograd with no cross terms from module reuse / inter-module param reuse: 30fdc00. Basically, the idea is to have a module pre-hook that clones each parameter before it is used, and a module post-hook that restores the original params. This way, each module usage corresponds to a different clone, and you can compute a gradient wrt each clone. The implementation is with a context manager so it's quite clean IMO. Current limitations:
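A minimal sketch of that idea (assuming plain `register_forward_pre_hook` / `register_forward_hook` parameter swapping; the actual `CloneParams` context in 30fdc00 may differ):

```python
import contextlib
from torch import nn

@contextlib.contextmanager
def clone_params_per_usage(model: nn.Module):
    """Swap each module's parameters for fresh clones on every forward call, so every
    usage of a (possibly shared) module gets its own autograd leaves, then restore
    the originals afterwards. The clones are collected per usage in `usages`."""
    usages = []     # one (module, {name: clone}) entry per module call
    originals = {}  # module -> {name: original nn.Parameter}
    handles = []

    def pre_hook(module, inputs):
        saved, clones = {}, {}
        for name, param in list(module.named_parameters(recurse=False)):
            clone = param.detach().clone().requires_grad_(param.requires_grad)
            saved[name] = param
            del module._parameters[name]   # unregister the Parameter
            setattr(module, name, clone)   # the forward pass now uses the clone
            clones[name] = clone
        if saved:
            originals[module] = saved
            usages.append((module, clones))

    def post_hook(module, inputs, outputs):
        for name, param in originals.pop(module, {}).items():
            delattr(module, name)             # drop the plain-tensor attribute
            module._parameters[name] = param  # re-register the original Parameter

    for m in model.modules():
        handles.append(m.register_forward_pre_hook(pre_hook))
        handles.append(m.register_forward_hook(post_hook))
    try:
        yield usages
    finally:
        for h in handles:
            h.remove()
```

Losses computed inside the context can then be differentiated with respect to the clones collected in `usages`, one set of leaves per module usage, so a shared module no longer mixes the contributions of its different calls.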
7cc40ca to 0ae4695
* Move conftest.py from tests.unit to tests
* Separate DEVICE creation from conftest.py to device.py
* Add pytest_make_parametrize_id
* Add support for RNN, BatchNorm2d and InstanceNorm2d in get_in_out_shapes
* Remove WithRNN, WithBatchNorm and WithModuleTrackingRunningStats - use simple factories instead
* Revert removal of WithRNN (part of aff0abc)
* Fix output of WithRNN to not include the hidden state
* Extract rng forking into contexts.py
* Make _forward_pass do rng forking
* Make _forward_pass take reduction parameter
* Make forward_pass public
* Use forward_pass in test_engine.py, stop reseeding (it's now done by forward_pass)
* Make zipping strict in make_mse_loss_fn
* Stop requiring params in autograd_gramian_forward_backward
* Improve parameter order of autogram_forward_backward
* Rename some variables
* Factorize input and target creation into make_inputs_and_targets
* Reorder some code
* Add CloneParams context to consider each parameter usage on a per-module-usage basis.
* Add _get_losses_and_params_with_cross_terms, _get_losses_and_params_without_cross_terms, and _get_losses_and_params to select between both.
1d134b4 to 7a95b96
    
No description provided.