Efficient Merging/Pruning Methods: Multi-Modal Transformer Models #20708
Replies: 1 comment
-
I found a bug related to how I was jitting functions in my performance benchmark tests. It now looks like models run faster with compression, as expected given the current implementations.
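I won't claim this was the exact issue, but the common pitfalls when jitting benchmark code are timing the first call (which includes compilation) and not waiting for JAX's asynchronous dispatch to finish. Here is a minimal sketch of a fair timing pattern; `time_jitted`, `toy_attention`, and the shapes are illustrative stand-ins, not the actual benchmark:

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, iters=10):
    """Time a jitted function, excluding compile time and forcing completion."""
    # Warm-up call triggers tracing/compilation, so it is excluded from the timing.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    out.block_until_ready()  # dispatch is asynchronous; wait for the result
    return (time.perf_counter() - start) / iters


# Illustrative stand-in for the attention block being benchmarked.
@jax.jit
def toy_attention(x):
    return jax.nn.softmax(x @ x.T, axis=-1) @ x


x = jnp.ones((256, 64))
print(f"mean time per call: {time_jitted(toy_attention, x):.6f} s")
```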
-
Hi Jax Community,
I am working on merging/pruning methods for a multi-modal transformer architecture in Flax. I am finding it challenging to get attention layers that use JAX-based pruning and merging methods to run as efficiently as regular attention blocks: FLOPs decrease, but model inference takes longer with the pruning/merging implementation despite the lower FLOPs.

I had the following questions and was hoping someone may be able to provide advice:

1. For my input embedding sequence I always subset the sequence based on the modality. In my architecture's `setup` method I precompute the indices of the sequence for each modality and pass these as static args to a prune/merge method. Within my prune/merge method I am currently using `jax.lax.dynamic_slice` to get my modality subsets. I have also experimented with using `jnp.take`, but the performance difference seemed negligible based on initial tests. Is there a recommended way to efficiently slice/subset arrays, given that I know the indices and can precompile my merge/prune methods with these indices as static args (for reference, in `setup` of the following Module I precompute these methods, and they are defined in the following file)? Would it make more sense to generate masks for the modalities here rather than indexing? (A sketch comparing the slice and mask options is included after the Colab link below.)

2. For pruning I am applying `jax.lax.approx_max_k`. I am trying to refactor my code so I can `vmap` this method over modality subsets, but the `k` parameter requires a static int, so I haven't been able to vmap over this arg; I also didn't want to compile a separate method for each variation of `k`. My original implementation of top_k pruning used a for loop over values; when added to my attention layer, the compiled version of this method seems to add overhead and not improve inference speed over the original attention layer implementation. (A sketch of a fixed-`k` plus masking workaround follows the slicing sketch below.)

I am going to continue debugging and will post a resolution should I find one; in the meantime, if anyone spots something obvious, it would be appreciated if they could chip in with advice.
Current Colab where I am debugging: https://colab.research.google.com/drive/1B1mg7r11d9DsPjHs1rwP4_oPXBAn7GHY#scrollTo=C3UQHvhOaqWU
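For question 1, here is a minimal sketch of the two options being compared; the modality layout `TEXT`/`IMAGE`, the shapes, and the helper names are made up for illustration. Because the start/size are static arguments, they are plain Python ints at trace time, so the slice has a fixed size; the mask variant instead keeps the full sequence shape, so the output shape does not depend on the modality:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical modality layout: (start, size) per modality along the sequence axis.
TEXT = (0, 16)
IMAGE = (16, 32)


@partial(jax.jit, static_argnums=(1, 2))
def subset_by_slice(x, start, size):
    # start/size are Python ints at trace time, so this is a fixed-size slice.
    return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)


@partial(jax.jit, static_argnums=(1, 2))
def subset_by_mask(x, start, size):
    # Mask variant: keep the full sequence length and zero out other modalities.
    positions = jnp.arange(x.shape[0])
    mask = (positions >= start) & (positions < start + size)
    return jnp.where(mask[:, None], x, 0.0)


x = jnp.ones((48, 64))                    # (seq_len, dim), made-up sizes
text_tokens = subset_by_slice(x, *TEXT)   # shape (16, 64)
image_masked = subset_by_mask(x, *IMAGE)  # shape (48, 64), image rows preserved
```

Whether the slice or the mask wins likely depends on how much downstream work the smaller sliced arrays save versus the cost of keeping the full-length masked sequence, which is what I am trying to measure in the Colab above.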
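For question 2, one possible workaround (a sketch under my own assumptions, not a definitive answer): since `k` for `jax.lax.approx_max_k` has to be a static int, pick a single static upper bound `MAX_K` shared by all modalities, and let each modality's real budget enter only through a mask, which can stay a regular traced array. The scores and per-modality budgets below are made up for illustration:

```python
import jax
import jax.numpy as jnp

MAX_K = 3  # static upper bound on k, shared by all modalities

# Hypothetical per-modality score rows, padded to a common length.
scores = jnp.array([[0.9, 0.1, 0.4, 0.7],
                    [0.2, 0.8, 0.5, 0.3],
                    [0.6, 0.6, 0.1, 0.2]])
k_per_modality = jnp.array([2, 3, 1])  # dynamic: how many tokens each modality keeps


@jax.jit
def prune_topk(scores, k_per_modality):
    # approx_max_k reduces along the last axis, so the modality axis just rides
    # along as a batch dimension; k is the same static MAX_K everywhere...
    vals, idx = jax.lax.approx_max_k(scores, MAX_K)          # (modalities, MAX_K)
    # ...and each modality's own (traced) budget only enters through a mask.
    keep = jnp.arange(MAX_K)[None, :] < k_per_modality[:, None]
    return jnp.where(keep, vals, -jnp.inf), jnp.where(keep, idx, -1)


vals, idx = prune_topk(scores, k_per_modality)  # both of shape (3, MAX_K)
```

Because `k` is the same static constant everywhere, this compiles once per input shape; the same idea should also work inside a `vmap`ped per-modality function, since the part that varies (`k_per_modality`) is an ordinary array argument rather than a static one.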