|
7 | 7 | import torch.nn as nn |
8 | 8 | from attr import dataclass |
9 | 9 |
|
| 10 | +try: |
| 11 | + from flash_attn.bert_padding import pad_input, unpad_input # type:ignore |
| 12 | + |
| 13 | + is_flash_attn_available = True |
| 14 | +except ImportError: |
| 15 | + pad_input = None |
| 16 | + unpad_input = None |
| 17 | + is_flash_attn_available = False |
| 18 | + |
10 | 19 | from esm.layers.regression_head import RegressionHead |
11 | 20 | from esm.layers.transformer_stack import TransformerStack |
12 | 21 | from esm.sdk.api import ( |
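
For context on the new optional dependency: `pad_input`/`unpad_input` from `flash_attn.bert_padding` strip padding rows before attention and restore them afterward, so the transformer only processes real tokens. The pure-PyTorch sketch below is illustrative only; the helper names are hypothetical and only the shape contract mirrors what this diff relies on.

```python
# Illustrative sketch (not part of this diff): a pure-PyTorch equivalent of the
# unpad/pad round-trip provided by flash_attn.bert_padding. Helper names are
# hypothetical; only the shape contract matches what the model code assumes.
import torch


def unpad_sketch(x: torch.Tensor, mask: torch.Tensor):
    # x: [B, L, D]; mask: [B, L] bool, True = real token, False = padding.
    B, L, D = x.shape
    indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    return x.reshape(B * L, D)[indices], indices  # [total_tokens, D], [total_tokens]


def pad_sketch(x_unpad: torch.Tensor, indices: torch.Tensor, B: int, L: int):
    # Scatter the unpadded rows back into a zero-filled [B, L, D] tensor.
    D = x_unpad.shape[-1]
    out = torch.zeros(B * L, D, dtype=x_unpad.dtype, device=x_unpad.device)
    out[indices] = x_unpad
    return out.reshape(B, L, D)


mask = torch.tensor([[True, True, False], [True, False, False]])
x = torch.randn(2, 3, 4)
x_unpad, idx = unpad_sketch(x, mask)                  # 3 real tokens -> [3, 4]
assert torch.equal(pad_sketch(x_unpad, idx, 2, 3)[mask], x[mask])
```
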
@@ -43,13 +52,26 @@ class ESMC(nn.Module, ESMCInferenceClient): |
43 | 52 | """ |
44 | 53 |
|
45 | 54 | def __init__( |
46 | | - self, d_model: int, n_heads: int, n_layers: int, tokenizer: EsmSequenceTokenizer |
| 55 | + self, |
| 56 | + d_model: int, |
| 57 | + n_heads: int, |
| 58 | + n_layers: int, |
| 59 | + tokenizer: EsmSequenceTokenizer, |
| 60 | + use_flash_attn: bool = True, |
47 | 61 | ): |
48 | 62 | super().__init__() |
49 | 63 | self.embed = nn.Embedding(64, d_model) |
| 64 | + |
| 65 | + self._use_flash_attn = is_flash_attn_available and use_flash_attn |
50 | 66 | self.transformer = TransformerStack( |
51 | | - d_model, n_heads, None, n_layers, n_layers_geom=0 |
| 67 | + d_model, |
| 68 | + n_heads, |
| 69 | + None, |
| 70 | + n_layers, |
| 71 | + n_layers_geom=0, |
| 72 | + use_flash_attn=self._use_flash_attn, |
52 | 73 | ) |
| 74 | + |
53 | 75 | self.sequence_head = RegressionHead(d_model, 64) |
54 | 76 | self.tokenizer = tokenizer |
55 | 77 |
|
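
A usage sketch for the new constructor flag. The import paths and layer sizes below are assumptions for illustration, not taken from this diff; `use_flash_attn=True` only takes effect when the `flash_attn` package imports successfully, so the flag is safe to leave on by default.

```python
# Constructor sketch; module paths and model sizes are assumed for illustration.
from esm.models.esmc import ESMC                                  # assumed path
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer  # assumed path

model = ESMC(
    d_model=960,
    n_heads=15,
    n_layers=30,
    tokenizer=EsmSequenceTokenizer(),
    use_flash_attn=False,  # force the standard attention path even if flash_attn is installed
)
```
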
@@ -109,10 +131,41 @@ def forward( |
109 | 131 |
|
110 | 132 | """ |
111 | 133 | if sequence_id is None: |
112 | | - sequence_id = sequence_tokens == self.tokenizer.pad_token_id |
| 134 | +            # For ESMC, a boolean padding mask (True = non-padding token) is created in place of sequence_id if not specified. |
| 135 | + sequence_id = sequence_tokens != self.tokenizer.pad_token_id |
113 | 136 |
|
114 | 137 | x = self.embed(sequence_tokens) |
| 138 | + |
| 139 | + B, L = x.shape[:2] |
| 140 | + |
| 141 | +        # The Flash Attention path requires sequence_id to be a boolean padding mask; unpad the input with it. |
| 142 | + if self._use_flash_attn: |
| 143 | + assert ( |
| 144 | + sequence_id.dtype == torch.bool |
| 145 | + ), "sequence_id must be a boolean mask if Flash Attention is used" |
| 146 | + assert sequence_id.shape == (B, L) |
| 147 | + assert unpad_input is not None |
| 148 | + x, indices, _, _, _ = unpad_input( # type: ignore |
| 149 | + x, sequence_id |
| 150 | + ) |
| 151 | + else: |
| 152 | + indices = None |
| 153 | + |
115 | 154 | x, _, hiddens = self.transformer(x, sequence_id=sequence_id) |
| 155 | + |
| 156 | + if self._use_flash_attn: |
| 157 | + assert indices is not None |
| 158 | + assert pad_input is not None |
| 159 | + x = pad_input(x, indices, B, L) # Back to [B, L, D] |
| 160 | + hiddens = [ |
| 161 | + # Back to [[B, L, D], ...] |
| 162 | + pad_input(h, indices, B, L) |
| 163 | + for h in hiddens |
| 164 | + ] |
| 165 | + |
| 166 | +            # Stack hidden states into a [n_layers, B, L, D] tensor. |
| 167 | + hiddens = torch.stack(hiddens, dim=0) # type: ignore |
| 168 | + |
116 | 169 | sequence_logits = self.sequence_head(x) |
117 | 170 | output = ESMCOutput( |
118 | 171 | sequence_logits=sequence_logits, embeddings=x, hidden_states=hiddens |
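
A small shape check (illustrative; sizes are arbitrary) for the flash-attention branch above: after `pad_input` the embeddings are back to `[B, L, D]`, and stacking the per-layer hidden states yields `[n_layers, B, L, D]`, which is what `ESMCOutput.hidden_states` now carries.

```python
# Shape sketch for the restored tensors (illustrative sizes, not model output).
import torch

B, L, D, n_layers = 2, 8, 16, 4
x = torch.randn(B, L, D)                                    # embeddings after pad_input
hiddens = [torch.randn(B, L, D) for _ in range(n_layers)]   # per-layer states after pad_input
hidden_states = torch.stack(hiddens, dim=0)                 # [n_layers, B, L, D]
assert hidden_states.shape == (n_layers, B, L, D)

# Padding rows come back zero-filled, so a mask-aware mean pool over real
# tokens is straightforward:
sequence_id = torch.arange(L).expand(B, L) < torch.tensor([[8], [5]])  # True = real token
mean_pooled = (x * sequence_id.unsqueeze(-1)).sum(1) / sequence_id.sum(1, keepdim=True)
assert mean_pooled.shape == (B, D)
```
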
@@ -161,4 +214,5 @@ def logits( |
161 | 214 | sequence=output.sequence_logits if config.sequence else None |
162 | 215 | ), |
163 | 216 | embeddings=output.embeddings if config.return_embeddings else None, |
| 217 | + hidden_states=output.hidden_states if config.return_hidden_states else None, |
164 | 218 | ) |
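
A usage sketch for the new `hidden_states` passthrough. The client setup is assumed (a constructed `model` and the `encode` path of the inference client); `ESMProtein` and `LogitsConfig` come from `esm.sdk.api` as imported near the top of this file, and the `LogitsConfig` field names follow this diff.

```python
# Usage sketch; assumes a constructed ESMC `model` and the esm.sdk.api names
# ESMProtein / LogitsConfig. LogitsConfig field names follow this diff.
from esm.sdk.api import ESMProtein, LogitsConfig

protein = ESMProtein(sequence="MPKLVRGQW")
protein_tensor = model.encode(protein)
out = model.logits(
    protein_tensor,
    LogitsConfig(sequence=True, return_embeddings=True, return_hidden_states=True),
)
# out.hidden_states is the stacked [n_layers, B, L, D] tensor built in forward().
print(out.hidden_states.shape)
```
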