
Commit 5a1f457

hidden states and flash attention (#168)
1 parent 41af605 · commit 5a1f457


15 files changed: +282 -52 lines changed


README.md

Lines changed: 28 additions & 15 deletions
@@ -1,15 +1,18 @@
-- [Installation ](#installation)
-- [ESM C](#esm-c)
-- [ESM C 300M and 600M via GitHub](#esm-c-github)
-- [ESM C via Forge API for Free Non-Commercial Use](#esm-c-forge)
-- [ESM C via SageMaker for Commercial Use](#esm-c-sagemaker)
-- [ESM C Example Usage](#esmc-example)
-- [ESM 3](#esm3)
-- [Quickstart for ESM3-open](#esm3-quickstart)
-- [Forge: Access to larger ESM3 models](#esm3-forge)
-- [ESM 3 Example Usage](#esm3-example)
-- [Responsible Development ](#responsible-development)
-- [Licenses](#licenses)
+- [Installation ](#installation-)
+- [ESM C ](#esm-c-)
+- [ESM C Local Models via GitHub ](#esm-c-local-models-via-github-)
+- [Using ESM C 6B via Forge API](#using-esm-c-6b-via-forge-api)
+- [ESM C via Forge API for Free Non-Commercial Use ](#esm-c-via-forge-api-for-free-non-commercial-use--)
+- [ESM C via SageMaker for Commercial Use ](#esm-c-via-sagemaker-for-commercial-use--)
+- [ESM C Example Usage](#esm-c-example-usage)
+- [ESM 3 ](#esm-3--)
+- [Quickstart for ESM3-open ](#quickstart-for-esm3-open-)
+- [EvolutionaryScale Forge: Access to larger ESM3 models](#evolutionaryscale-forge-access-to-larger-esm3-models)
+- [ESM3 Example Usage](#esm3-example-usage)
+- [Responsible Development ](#responsible-development-)
+- [Licenses ](#licenses--)
+- [How can I access the models and which licenses apply?](#how-can-i-access-the-models-and-which-licenses-apply)
+- [What changed with the release of ESM C?](#what-changed-with-the-release-of-esm-c)
 
 
 ## Installation <a name="installation"></a>
@@ -46,6 +49,16 @@ logits_output = client.logits(
 print(logits_output.logits, logits_output.embeddings)
 ```
 
+To use Flash Attention with the open weights:
+
+Simply install the flash-attn package, which will enable Flash Attention automatically:
+```
+pip install flash-attn --no-build-isolation
+```
+
+You can also disable flash-attn by passing ``use_flash_attn=False`` to loader functions like ``ESMC_300M_202412``.
+
+### Using ESM C 6B via Forge API
 ### ESM C via Forge API for Free Non-Commercial Use <a name="esm-c-forge"></a>
 
 The ESM C model family, including ESMC 6B, are accessible via EvolutionaryScale Forge for free [non-commercial use](#licenses).
@@ -235,13 +248,13 @@ The models can be accessed in three different ways, each with its own licensing
 1. **Code and weights** via GitHub and HuggingFace are available under either a [non-commercial](https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement) (ESM C 600M, ESM3-small-open) or an [open license](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement) (codebase, ESM C 300M).
    1. **Building with ESM encouraged**: You can use embeddings, model predictions, fine-tune the models and use components of both the models and code. We strongly encourage anyone to build on ESM C and ESM3! Just remember to maintain the same license terms and release under the ESM name.
 2. **Free non-commercial inference API** via Forge. All models are available this way, with free credits granted to students and researchers. We want to enable academics under [non-commercial Terms of Use](https://www.evolutionaryscale.ai/policies/terms-of-use), which mirrors the non-commercial license.
-3. **Paid commercial Inference API** for commercial use via SageMaker (Forge coming soon). All ESM C models are available this way to commercial entities for commercial use under a [clickthrough license agreement](https://www.evolutionaryscale.ai/policies/cambrian-inference-clickthrough-license-agreement) with few restrictions.
+3. **Paid commercial Inference API** for commercial use via SageMaker (Forge coming soon). All ESM C models are available this way to commercial entities for commercial use under a [clickthrough license agreement](https://www.evolutionaryscale.ai/policies/cambrian-inference-clickthrough-license-agreement) with few restrictions.
    1. In broad strokes: standard commercial use like developing molecules and developing downstream ML models and methods with the model is allowed, while training competing models on the API outputs is not.
-   2. Note: For ESM3 commercial use, reach out to [[email protected]](mailto:[email protected])
+   2. Note: For ESM3 commercial use, reach out to [[email protected]](mailto:[email protected])
 
 ### What changed with the release of ESM C?
 
-We introduced a [clickthrough license agreement](https://www.evolutionaryscale.ai/policies/cambrian-inference-clickthrough-license-agreement) to enable frictionless commercial use of ESM C.
+We introduced a [clickthrough license agreement](https://www.evolutionaryscale.ai/policies/cambrian-inference-clickthrough-license-agreement) to enable frictionless commercial use of ESM C.
 
 We introduced the new [Cambrian Open License](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement) for ESM C 300M, and at the same time moved all code in the [`esm` repo](https://github.com/evolutionaryscale/esm) under that permissive license.
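The Flash Attention note added to the README above boils down to two knobs: installing flash-attn turns the fast path on, and the `use_flash_attn` flag on the open-weights loaders turns it off. A minimal sketch, not part of this commit, assuming a CUDA device and that the checkpoint can be fetched:

```python
# Sketch of the README instructions above. Flash Attention is picked up
# automatically when the flash-attn package is importable; the flag opts out.
from esm.pretrained import ESMC_300M_202412

model = ESMC_300M_202412(device="cuda")  # flash path if flash-attn is installed
model_no_flash = ESMC_300M_202412(device="cuda", use_flash_attn=False)  # plain SDPA path
```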

esm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-__version__ = "3.1.1"
+__version__ = "3.1.2"

esm/layers/attention.py

Lines changed: 52 additions & 3 deletions
@@ -5,7 +5,15 @@
 import torch.nn.functional as F
 from torch import nn
 
-from esm.layers.rotary import RotaryEmbedding
+from esm.layers.rotary import (
+    RotaryEmbedding,
+    TritonRotaryEmbedding,
+)
+
+try:
+    from flash_attn import flash_attn_varlen_qkvpacked_func  # type:ignore
+except ImportError:
+    flash_attn_varlen_qkvpacked_func = None  # fall back to the plain SDPA path
 
 
 class MultiHeadAttention(nn.Module):
@@ -49,9 +57,8 @@ def forward(self, x, seq_id):
         )
         query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
 
-        n_heads = self.n_heads
         reshaper = functools.partial(
-            einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads
+            einops.rearrange, pattern="b s (h d) -> b h s d", h=self.n_heads
         )
 
         query_BHLD, key_BHLD, value_BHLD = map(
@@ -72,5 +79,47 @@ def forward(self, x, seq_id):
         context_BHLD = F.scaled_dot_product_attention(
             query_BHLD, key_BHLD, value_BHLD
         )
+
         context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
+
         return self.out_proj(context_BLD)
+
+
+class FlashMultiHeadAttention(MultiHeadAttention):
+    def __init__(
+        self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True
+    ):
+        super().__init__(
+            d_model=d_model, n_heads=n_heads, bias=bias, qk_layernorm=qk_layernorm
+        )
+
+        # Flash attention rotary.
+        self.rotary = TritonRotaryEmbedding(d_model // n_heads)
+
+    def forward(self, x, seq_id):
+        assert seq_id.dtype == torch.bool
+
+        seqlens = seq_id.sum(dim=-1, dtype=torch.int32)
+        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
+        max_seqlen = seqlens.max().item()
+
+        qkv_ND3 = self.layernorm_qkv(x)
+
+        query_ND, key_ND, value_ND = torch.chunk(qkv_ND3, 3, dim=-1)
+        query_ND, key_ND = (
+            self.q_ln(query_ND).to(query_ND.dtype),
+            self.k_ln(key_ND).to(query_ND.dtype),
+        )
+
+        qkv_N3D = torch.stack([query_ND, key_ND, value_ND], dim=1)
+        qkv_N3HD = einops.rearrange(
+            qkv_N3D, pattern="n a (h d) -> n a h d", h=self.n_heads
+        )
+        qkv_N3HD = self.rotary(qkv_N3HD, cu_seqlens, max_seqlen)
+
+        context_NHD = flash_attn_varlen_qkvpacked_func(
+            qkv_N3HD, cu_seqlens, max_seqlen, softmax_scale=self.d_head**-0.5
+        )
+        context_ND = einops.rearrange(context_NHD, "n h d -> n (h d)")
+
+        return self.out_proj(context_ND)
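The varlen kernel used by `FlashMultiHeadAttention` sees one packed tensor of tokens rather than a padded batch, so the boolean `seq_id` mask is reduced to cumulative sequence offsets. A standalone sketch of that bookkeeping, mirroring the lines added above (values are illustrative only, not part of this commit):

```python
import torch
import torch.nn.functional as F

# Two sequences of lengths 3 and 4 in a padded batch of width 4.
seq_id = torch.tensor([
    [True, True, True, False],
    [True, True, True, True],
])

seqlens = seq_id.sum(dim=-1, dtype=torch.int32)                              # tensor([3, 4])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 7])
max_seqlen = int(seqlens.max().item())                                       # 4

# cu_seqlens[i]:cu_seqlens[i + 1] indexes the i-th sequence inside the packed
# (total_tokens, ...) tensor consumed by flash_attn_varlen_qkvpacked_func.
print(cu_seqlens, max_seqlen)
```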

esm/layers/blocks.py

Lines changed: 13 additions & 4 deletions
@@ -2,7 +2,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from esm.layers.attention import MultiHeadAttention
+from esm.layers.attention import (
+    FlashMultiHeadAttention,
+    MultiHeadAttention,
+)
 from esm.layers.geom_attention import (
     GeometricReasoningOriginalImpl,
 )
@@ -78,6 +81,7 @@ def __init__(
         n_heads: int,
         use_geom_attn: bool = False,
         use_plain_attn: bool = True,
+        use_flash_attn: bool = False,
         v_heads: int | None = None,
         bias: bool = False,
         expansion_ratio: float = 4.0,
@@ -89,9 +93,14 @@
         super().__init__()
         self.use_plain_attn = use_plain_attn
         if self.use_plain_attn:
-            self.attn = MultiHeadAttention(
-                d_model, n_heads, bias, qk_layernorm=qk_layernorm
-            )
+            if use_flash_attn:
+                self.attn = FlashMultiHeadAttention(
+                    d_model, n_heads, bias, qk_layernorm=qk_layernorm
+                )
+            else:
+                self.attn = MultiHeadAttention(
+                    d_model, n_heads, bias, qk_layernorm=qk_layernorm
+                )
         self.use_geom_attn = use_geom_attn
         if self.use_geom_attn:
             if v_heads is None:
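The new branch simply routes each block to one of the two attention implementations, with the flag threaded down from `ESMC` through `TransformerStack`. An equivalent way to read it, as a hypothetical helper rather than the committed code:

```python
from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention

def build_attention(d_model: int, n_heads: int, bias: bool, qk_layernorm: bool, use_flash_attn: bool):
    # Hypothetical helper (not in the repo), equivalent to the if/else added above.
    attn_cls = FlashMultiHeadAttention if use_flash_attn else MultiHeadAttention
    return attn_cls(d_model, n_heads, bias, qk_layernorm=qk_layernorm)
```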

esm/layers/rotary.py

Lines changed: 40 additions & 0 deletions
@@ -25,6 +25,13 @@
 import torch
 from einops import rearrange, repeat
 
+try:
+    from flash_attn.ops.triton.rotary import (  # type:ignore
+        apply_rotary as apply_triton_rotary,
+    )
+except ImportError:
+    apply_triton_rotary = None
+
 
 def rotate_half(x, interleaved=False):
     if not interleaved:
@@ -219,3 +226,36 @@ def forward(
             )  # type: ignore
         else:
             assert False
+
+
+class TritonRotaryEmbedding(RotaryEmbedding):
+    def forward(self, qkv: torch.Tensor, cu_seqlens, max_seqlen) -> torch.Tensor:
+        """
+        qkv: (n, 3, nheads, headdim)
+        cu_seqlens: cumulative sequence lengths
+        max_seqlen: max sequence length
+        """
+        self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+        assert self._cos_cached is not None
+        assert self._sin_cached is not None
+
+        assert apply_triton_rotary is not None
+        # In-place rotation of the query (qkv[:, 0]) and key (qkv[:, 1]) slices.
+        apply_triton_rotary(
+            qkv[:, 0],
+            self._cos_cached,
+            self._sin_cached,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            inplace=True,
+        )
+        apply_triton_rotary(
+            qkv[:, 1],
+            self._cos_cached,
+            self._sin_cached,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            inplace=True,
+        )
+
+        return qkv
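`TritonRotaryEmbedding` reuses the parent class's cos/sin cache and rotates only the query and key slices of the packed qkv tensor in place via flash-attn's Triton kernel. For reference, a pure-PyTorch sketch of the same non-interleaved rotation for a single unpacked sequence; the shapes and cos/sin layout here are illustrative assumptions, and the Triton kernel additionally handles `cu_seqlens`-packed batches:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # GPT-NeoX style: split the head dim into halves and rotate (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_reference(t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # t: (seqlen, n_heads, head_dim); cos, sin: (seqlen, head_dim // 2).
    cos = torch.cat((cos, cos), dim=-1)[:, None, :]  # broadcast over heads
    sin = torch.cat((sin, sin), dim=-1)[:, None, :]
    return t * cos + rotate_half(t) * sin
```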

esm/layers/transformer_stack.py

Lines changed: 3 additions & 2 deletions
@@ -36,6 +36,7 @@ def __init__(
         qk_layernorm: bool = True,
         ffn_type: str = "swiglu",  # swiglu | gelu
         expansion_ratio: float = 8 / 3,
+        use_flash_attn: bool = False,
     ):
         super().__init__()
         self.blocks = nn.ModuleList(
@@ -45,6 +46,7 @@
                     n_heads,
                     v_heads=v_heads,
                     use_geom_attn=i < n_layers_geom,
+                    use_flash_attn=use_flash_attn,
                     residue_scaling_factor=(
                         math.sqrt(n_layers / 36) if scale_residue else 1.0
                     ),
@@ -66,7 +68,7 @@
         affine: Affine3D | None = None,
         affine_mask: torch.Tensor | None = None,
         chain_id: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]:
         """
         Forward pass of the TransformerStack.
 
@@ -89,5 +91,4 @@
         for block in self.blocks:
             x = block(x, sequence_id, affine, affine_mask, chain_id)
             hiddens.append(x)
-        hiddens = torch.stack(hiddens, dim=0)
         return self.norm(x), x, hiddens
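With `hiddens` now returned as a plain list (one tensor per block) instead of a pre-stacked tensor, stacking is left to the caller; `ESMC.forward` below does it after re-padding. A small sketch of the new return type, with arbitrary sizes and the assumption that the stack's defaults allow constructing it standalone:

```python
import torch

from esm.layers.transformer_stack import TransformerStack

# Tiny stack purely for illustration, mirroring how ESMC constructs it.
stack = TransformerStack(64, 4, None, 2, n_layers_geom=0)
x = torch.randn(1, 8, 64)

x_normed, x_last, hiddens = stack(x, sequence_id=torch.ones(1, 8, dtype=torch.bool))
print(len(hiddens))                       # 2: one hidden state per block
print(torch.stack(hiddens, dim=0).shape)  # torch.Size([2, 1, 8, 64])
```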

esm/models/esmc.py

Lines changed: 57 additions & 3 deletions
@@ -7,6 +7,15 @@
 import torch.nn as nn
 from attr import dataclass
 
+try:
+    from flash_attn.bert_padding import pad_input, unpad_input  # type:ignore
+
+    is_flash_attn_available = True
+except ImportError:
+    pad_input = None
+    unpad_input = None
+    is_flash_attn_available = False
+
 from esm.layers.regression_head import RegressionHead
 from esm.layers.transformer_stack import TransformerStack
 from esm.sdk.api import (
@@ -43,13 +52,26 @@ class ESMC(nn.Module, ESMCInferenceClient):
     """
 
     def __init__(
-        self, d_model: int, n_heads: int, n_layers: int, tokenizer: EsmSequenceTokenizer
+        self,
+        d_model: int,
+        n_heads: int,
+        n_layers: int,
+        tokenizer: EsmSequenceTokenizer,
+        use_flash_attn: bool = True,
     ):
         super().__init__()
         self.embed = nn.Embedding(64, d_model)
+
+        self._use_flash_attn = is_flash_attn_available and use_flash_attn
         self.transformer = TransformerStack(
-            d_model, n_heads, None, n_layers, n_layers_geom=0
+            d_model,
+            n_heads,
+            None,
+            n_layers,
+            n_layers_geom=0,
+            use_flash_attn=self._use_flash_attn,
         )
+
         self.sequence_head = RegressionHead(d_model, 64)
         self.tokenizer = tokenizer
 
@@ -109,10 +131,41 @@ def forward(
 
         """
         if sequence_id is None:
-            sequence_id = sequence_tokens == self.tokenizer.pad_token_id
+            # For ESMC, a boolean mask is created in place of sequence_id if not specified.
+            sequence_id = sequence_tokens != self.tokenizer.pad_token_id
 
         x = self.embed(sequence_tokens)
+
+        B, L = x.shape[:2]
+
+        # With Flash Attention, sequence_id must be a boolean mask and the batch is unpadded (packed).
+        if self._use_flash_attn:
+            assert (
+                sequence_id.dtype == torch.bool
+            ), "sequence_id must be a boolean mask if Flash Attention is used"
+            assert sequence_id.shape == (B, L)
+            assert unpad_input is not None
+            x, indices, _, _, _ = unpad_input(  # type: ignore
+                x, sequence_id
+            )
+        else:
+            indices = None
+
         x, _, hiddens = self.transformer(x, sequence_id=sequence_id)
+
+        if self._use_flash_attn:
+            assert indices is not None
+            assert pad_input is not None
+            x = pad_input(x, indices, B, L)  # Back to [B, L, D]
+            hiddens = [
+                # Back to [[B, L, D], ...]
+                pad_input(h, indices, B, L)
+                for h in hiddens
+            ]
+
+        # Stack hidden states into a [n_layers, B, L, D] matrix.
+        hiddens = torch.stack(hiddens, dim=0)  # type: ignore
+
         sequence_logits = self.sequence_head(x)
         output = ESMCOutput(
             sequence_logits=sequence_logits, embeddings=x, hidden_states=hiddens
@@ -161,4 +214,5 @@ def logits(
                 sequence=output.sequence_logits if config.sequence else None
             ),
             embeddings=output.embeddings if config.return_embeddings else None,
+            hidden_states=output.hidden_states if config.return_hidden_states else None,
         )
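The `logits()` change wires the stacked hidden states through to the output when the config asks for them. A hedged sketch of requesting them through the client API; it assumes `LogitsConfig` exposes a `return_hidden_states` flag (as `config.return_hidden_states` above suggests), the `ESMC.from_pretrained("esmc_300m")` entry point from the README's ESM C example, and a GPU with the weights available:

```python
import torch

from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

client = ESMC.from_pretrained("esmc_300m").to("cuda")
protein = ESMProtein(sequence="AAAAA")
protein_tensor = client.encode(protein)

output = client.logits(
    protein_tensor,
    LogitsConfig(sequence=True, return_embeddings=True, return_hidden_states=True),
)
print(output.embeddings.shape)     # expected [1, sequence_length, d_model]
print(output.hidden_states.shape)  # expected [n_layers, 1, sequence_length, d_model]
```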

esm/pretrained.py

Lines changed: 12 additions & 4 deletions
@@ -62,10 +62,14 @@ def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
     return model
 
 
-def ESMC_300M_202412(device: torch.device | str = "cpu"):
+def ESMC_300M_202412(device: torch.device | str = "cpu", use_flash_attn: bool = True):
     with torch.device(device):
         model = ESMC(
-            d_model=960, n_heads=15, n_layers=30, tokenizer=get_esmc_model_tokenizers()
+            d_model=960,
+            n_heads=15,
+            n_layers=30,
+            tokenizer=get_esmc_model_tokenizers(),
+            use_flash_attn=use_flash_attn,
         ).eval()
         state_dict = torch.load(
             data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
@@ -76,10 +80,14 @@ def ESMC_300M_202412(device: torch.device | str = "cpu"):
     return model
 
 
-def ESMC_600M_202412(device: torch.device | str = "cpu"):
+def ESMC_600M_202412(device: torch.device | str = "cpu", use_flash_attn: bool = True):
     with torch.device(device):
         model = ESMC(
-            d_model=1152, n_heads=18, n_layers=36, tokenizer=get_esmc_model_tokenizers()
+            d_model=1152,
+            n_heads=18,
+            n_layers=36,
+            tokenizer=get_esmc_model_tokenizers(),
+            use_flash_attn=use_flash_attn,
         ).eval()
         state_dict = torch.load(
             data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
