Feature: Returning ESM-2 hidden layers #4
Open
JR-1991 wants to merge 5 commits into patrick-kidger:main from
The ESM2Result class now includes an all_hidden attribute containing outputs from all layers, not just the final layer. The ESM2 model's __call__ method is updated to collect and return these intermediate representations, providing more detailed model output for downstream analysis.
Changed ESM2Result.all_hidden from a list to a Float array with shape (num_layers, length, embed_size). Added a test to verify the shape of all_hidden output and ensure consistency for both token and string inputs.
Introduces a test to verify the shape of the all_hidden output from the ESM2 model for both tokenized and string inputs. Ensures that all_hidden has the expected dimensions based on the number of layers, sequence length, and embedding size.
The __len__ method, which returned the number of layers, has been removed from the ESM2 class. This simplifies the class interface and removes an unused or unnecessary method.
esm2quinox/_esm2.py
Outdated
    hidden: Float[Array, "length embed_size"]
    logits: Float[Array, "length alphabet_size"]
    all_hidden: Float[Array, "num_layers length embed_size"]
Owner
Thanks for the contribution!
Looking at this line, I think it would be more efficient to represent the hidden state as a list[Float[Array, "length embed_size"]].
In the typical use case, where a user needs only a few of the layers, the compiler could then dead-code-eliminate (DCE) the remaining elements of the list.
Author
Thanks for the quick feedback! That makes sense, will update the PR :)
Modified ESM2Result to store all_hidden as a list of arrays instead of a single stacked array. Updated tests to check for list type and correct shapes for each element, ensuring compatibility with the new structure.
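The reviewer's point about dead-code elimination can be illustrated with a minimal JAX sketch. Everything here is an assumption made for illustration: `toy_layers` is a hypothetical stand-in for the model's layer stack, not the esm2quinox implementation, and the sizes are arbitrary. It only shows the usage pattern in which a jitted caller keeps a single element of the returned list, leaving the later layers unused inside that compiled function.

```python
import jax
import jax.numpy as jnp

# Arbitrary illustrative sizes (assumptions, not taken from the PR).
NUM_LAYERS, LENGTH, EMBED = 4, 8, 16


def toy_layers(weights, x):
    """Hypothetical stand-in for a layer stack: one hidden state per layer."""
    all_hidden = []
    for w in weights:
        x = jnp.tanh(x @ w)
        all_hidden.append(x)
    return all_hidden  # list of (LENGTH, EMBED) arrays


@jax.jit
def second_layer_only(weights, x):
    # Only index 1 is returned, so the computations for layers 2 and 3
    # have no consumer inside this jitted function and are candidates
    # for removal by the compiler.
    return toy_layers(weights, x)[1]


wkey, xkey = jax.random.split(jax.random.PRNGKey(0))
weights = list(jax.random.normal(wkey, (NUM_LAYERS, EMBED, EMBED)))
x = jax.random.normal(xkey, (LENGTH, EMBED))

out = second_layer_only(weights, x)
print(out.shape)
```

Whether the unused layers are actually eliminated depends on the compiler; the sketch does not measure that, it only shows why a list of per-layer arrays keeps each layer's output independently addressable.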
Author
@patrick-kidger, I’ve updated the PR to use a |
Following up on the suggestion we talked about a few months back, this pull request introduces support for retrieving all hidden layers of the ESM2 model for layer sweeps and inspections. To achieve this, this PR extends the ESM2Result class and updates the model’s forward pass to gather and return these intermediate outputs. Tests have also been added to ensure the new features work as expected.

Model output enhancements:
- The ESM2Result class now includes a new .all_hidden attribute, which contains the hidden representations from all layers (including the final layer).
- The __call__ method is updated to collect and return the outputs from all layers in the .all_hidden field.

Testing:
- A new test, test_all_hidden_output, verifies that the .all_hidden output has the correct shape and is present when calling the model with both tokenized and string inputs.
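As a rough sketch of the kind of check described above, the helper below verifies that a list of per-layer hidden states has the expected length and element shapes. `check_all_hidden` and the dummy data are hypothetical and are not the test added in this PR; a real test would obtain `all_hidden` from the model's result rather than from zeros.

```python
import jax.numpy as jnp


def check_all_hidden(all_hidden, num_layers, length, embed_size):
    """Hypothetical helper: assert list type, layer count, and per-layer shape."""
    assert isinstance(all_hidden, list)
    assert len(all_hidden) == num_layers
    for hidden in all_hidden:
        assert hidden.shape == (length, embed_size)


# Dummy stand-in for a model output: 3 layers, sequence length 5, embed size 8.
fake_all_hidden = [jnp.zeros((5, 8)) for _ in range(3)]
check_all_hidden(fake_all_hidden, num_layers=3, length=5, embed_size=8)
```

Running the same helper on the outputs for a tokenized input and for the equivalent string input would cover both call paths mentioned in the description.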