Commit 73cd13d

Flatten esm-2 model package (#1464)

As we're adding the ESM-C model, we're going to need to install the `esm` package in the devcontainer for running ESM-C model tests. This will conflict with the `esm` pseudo-package I created for the original ESM-2 model folder. This PR just flattens the model into a single folder with a requirements.txt, like llama3 and the other, newer models.

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added a comprehensive export workflow for converting HuggingFace ESM-2 checkpoints to Transformer Engine format, including parameter count formatting and verification.
* **Chores**
  * Reorganized the project structure and moved dependencies from pyproject.toml to modular configuration files.
  * Updated import paths and development setup instructions throughout the codebase.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
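For a sense of what the flattened layout means in practice, here is a minimal sketch (assuming the working directory is `bionemo-recipes/models/esm2` and dependencies were installed with `pip install -r requirements.txt`, so the flattened modules import as top-level names):

```python
# With the package flattened, the converter is imported directly from the model
# folder rather than from the old `esm` pseudo-package.
from transformers import AutoModelForMaskedLM

from convert import convert_esm_hf_to_te  # previously: from esm.convert import convert_esm_hf_to_te

hf_model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
te_model = convert_esm_hf_to_te(hf_model)
```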
1 parent 6209b35 commit 73cd13d

24 files changed: +173 -197 lines changed
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+extend = "../.ruff.toml"

bionemo-recipes/models/esm2/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@ FROM nvcr.io/nvidia/pytorch:26.01-py3
 WORKDIR /workspace/bionemo
 COPY . .
 RUN --mount=type=cache,target=/root/.cache/pip \
-    PIP_CONSTRAINT= pip install -e .
+    PIP_CONSTRAINT= pip install -r requirements.txt

bionemo-recipes/models/esm2/README.md

Lines changed: 5 additions & 5 deletions
@@ -80,7 +80,7 @@ Hugging Face Transformers format for sharing and deployment. The workflow involv
 ```python
 from transformers import AutoModelForMaskedLM

-from esm.convert import convert_esm_hf_to_te
+from convert import convert_esm_hf_to_te

 hf_model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
 te_model = convert_esm_hf_to_te(hf_model)
@@ -92,8 +92,8 @@ This loads the pre-trained ESM2 model that will serve as our reference for compa
 ### Converting from TE back to HF Transformers

 ```python
-from esm.convert import convert_esm_te_to_hf
-from esm.modeling_esm_te import NVEsmForMaskedLM
+from convert import convert_esm_te_to_hf
+from modeling_esm_te import NVEsmForMaskedLM

 te_model = NVEsmForMaskedLM.from_pretrained("/path/to/te_checkpoint")
 hf_model = convert_esm_te_to_hf(te_model)
@@ -130,8 +130,8 @@ To run tests locally, run `recipes_local_test.py` from the repository root with
 ### Development container

 To use the provided devcontainer, use "Dev Containers: Reopen in Container" from the VSCode menu, and choose the
-"BioNeMo Recipes Dev Container" option. To run the tests inside the container, first install the model package in
-editable mode with `pip install -e .`, then run `pytest -v .` in the model directory.
+"BioNeMo Recipes Dev Container" option. To run the tests inside the container, first install the dependencies with
+`pip install -r requirements.txt`, then run `pytest -v .` in the model directory.

 ### Deploying converted checkpoints to HuggingFace Hub

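To complement the snippets in the diff above, here is a brief sketch of the reverse direction under the flattened layout (run from `bionemo-recipes/models/esm2`; the checkpoint path is a placeholder, and the forward-pass check at the end is illustrative rather than part of the README):

```python
import torch

from convert import convert_esm_te_to_hf
from modeling_esm_te import NVEsmForMaskedLM

# Load a previously exported Transformer Engine checkpoint and convert it back
# to a plain HF Transformers model for comparison or sharing.
te_model = NVEsmForMaskedLM.from_pretrained("/path/to/te_checkpoint")
hf_model = convert_esm_te_to_hf(te_model)

# Illustrative sanity check: the converted model should run a standard forward pass.
input_ids = torch.randint(0, hf_model.config.vocab_size, (1, 16))
with torch.no_grad():
    logits = hf_model(input_ids=input_ids).logits
print(logits.shape)  # (1, 16, vocab_size)
```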
File renamed without changes.

bionemo-recipes/models/esm2/src/esm/convert.py renamed to bionemo-recipes/models/esm2/convert.py

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@
 from torch import nn
 from transformers import EsmConfig, EsmForMaskedLM

-import esm.state as io
-from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
+import state as io
+from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM


 mapping = {

bionemo-recipes/models/esm2/export.py

Lines changed: 98 additions & 1 deletion
@@ -13,9 +13,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
+import json
+import shutil
 from pathlib import Path

-from esm.export import export_hf_checkpoint
+import torch
+from jinja2 import Template
+from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer
+
+from convert import convert_esm_hf_to_te
+from modeling_esm_te import AUTO_MAP
+
+
+BENCHMARK_RESULTS = {
+    "esm2_t6_8M_UR50D": {"CAMEO": 0.48, "CASP14": 0.37},
+    "esm2_t12_35M_UR50D": {"CAMEO": 0.56, "CASP14": 0.41},
+    "esm2_t30_150M_UR50D": {"CAMEO": 0.65, "CASP14": 0.49},
+    "esm2_t33_650M_UR50D": {"CAMEO": 0.70, "CASP14": 0.51},
+    "esm2_t36_3B_UR50D": {"CAMEO": 0.72, "CASP14": 0.52},
+    "esm2_t48_15B_UR50D": {"CAMEO": 0.72, "CASP14": 0.55},
+}


 ESM_TAGS = [
@@ -28,6 +46,85 @@
 ]


+def format_parameter_count(num_params: int, sig: int = 1) -> str:
+    """Format parameter count in scientific notation (e.g., 6.5 x 10^8).
+
+    Args:
+        num_params: Total number of parameters
+        sig: Number of digits to include after the decimal point
+
+    Returns:
+        Formatted string in scientific notation
+    """
+    s = f"{num_params:.{sig}e}"
+    base, exp = s.split("e")
+    return f"{base} x 10^{int(exp)}"
+
+
+def export_hf_checkpoint(tag: str, export_path: Path):
+    """Export a Hugging Face checkpoint to a Transformer Engine checkpoint.
+
+    Args:
+        tag: The tag of the checkpoint to export.
+        export_path: The parent path to export the checkpoint to.
+    """
+    model_hf_masked_lm = AutoModelForMaskedLM.from_pretrained(f"facebook/{tag}")
+    model_hf = AutoModel.from_pretrained(f"facebook/{tag}")
+    model_hf_masked_lm.esm.pooler = model_hf.pooler
+    model_te = convert_esm_hf_to_te(model_hf_masked_lm)
+    model_te.save_pretrained(export_path / tag)
+
+    tokenizer = AutoTokenizer.from_pretrained("esm_fast_tokenizer")  # Use our PreTrainedTokenizerFast implementation.
+    tokenizer.save_pretrained(export_path / tag)
+
+    # Patch the config
+    with open(export_path / tag / "config.json", "r") as f:
+        config = json.load(f)
+
+    config["auto_map"] = AUTO_MAP
+
+    with open(export_path / tag / "config.json", "w") as f:
+        json.dump(config, f, indent=2, sort_keys=True)
+
+    shutil.copy("modeling_esm_te.py", export_path / tag / "esm_nv.py")
+
+    # Calculate model parameters and render README template
+    num_params = sum(p.numel() for p in model_te.parameters())
+    formatted_params = format_parameter_count(num_params)
+
+    # Read and render the template
+    with open("model_readme.template", "r", encoding="utf-8") as f:
+        template_content = f.read()
+
+    template = Template(template_content)
+    rendered_readme = template.render(
+        num_params=formatted_params,
+        model_tag=tag,
+        cameo_score=BENCHMARK_RESULTS[tag]["CAMEO"],
+        casp14_score=BENCHMARK_RESULTS[tag]["CASP14"],
+    )
+
+    # Write the rendered README
+    with open(export_path / tag / "README.md", "w") as f:
+        f.write(rendered_readme)
+
+    shutil.copy("LICENSE", export_path / tag / "LICENSE")
+
+    del model_hf, model_te, model_hf_masked_lm
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Smoke test that the model can be loaded.
+    model_te = AutoModelForMaskedLM.from_pretrained(
+        export_path / tag,
+        dtype=torch.bfloat16,
+        trust_remote_code=True,
+    )
+    del model_te
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
 def main():
     """Export the ESM2 models from Hugging Face to the Transformer Engine format."""
     # TODO (peter): maybe add a way to specify the model to export or option to export all models?
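As a quick illustration of the helpers added in this file, a short usage sketch (the expected strings follow directly from `format_parameter_count` as written; the export directory is hypothetical, and the export itself assumes access to the Hugging Face Hub and enough memory for the chosen checkpoint):

```python
from pathlib import Path

from export import export_hf_checkpoint, format_parameter_count

# Parameter counts rendered in scientific notation:
print(format_parameter_count(8_000_000))               # 8.0 x 10^6
print(format_parameter_count(650_000_000))             # 6.5 x 10^8
print(format_parameter_count(15_000_000_000, sig=2))   # 1.50 x 10^10

# Export a single checkpoint to a local directory (hypothetical path), run from
# the model folder so modeling_esm_te.py, LICENSE, and the README template resolve.
export_hf_checkpoint("esm2_t6_8M_UR50D", Path("./te_checkpoints"))
```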
File renamed without changes.

bionemo-recipes/models/esm2/pyproject.toml

Lines changed: 0 additions & 34 deletions
This file was deleted.
bionemo-recipes/models/esm2/requirements.txt

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+accelerate
+datasets
+hydra-core
+jinja2
+megatron-fsdp
+omegaconf
+peft
+torch
+torchao!=0.14.0
+transformer_engine[pytorch]
+transformers
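Optionally, a small sketch for checking an existing environment against the one pinned entry above (`torchao!=0.14.0`); everything else in the list is unpinned:

```python
from importlib.metadata import PackageNotFoundError, version

# requirements.txt excludes torchao 0.14.0; verify the installed version if present.
try:
    assert version("torchao") != "0.14.0", "torchao 0.14.0 is excluded by requirements.txt"
except PackageNotFoundError:
    print("torchao not installed yet; run `pip install -r requirements.txt` first")
```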

bionemo-recipes/models/esm2/src/esm/__init__.py

Lines changed: 0 additions & 14 deletions
This file was deleted.
