From 24c7ce915aaff9660a99d1e9a024767d4b7af37a Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 6 Jan 2026 13:24:16 -0800 Subject: [PATCH 1/6] adds modeling te to disable fp8 on first and last layers in esm2 Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/modeling_esm_te.py | 703 ++++++++++++++++++ .../recipes/esm2_native_te/train_fsdp2.py | 8 +- ci/scripts/check_copied_files.py | 1 + 3 files changed, 710 insertions(+), 2 deletions(-) create mode 100644 bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py new file mode 100644 index 0000000000..144921bb06 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -0,0 +1,703 @@ +# coding=utf-8 +# noqa: license-check +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""TransformerEngine-optimized ESM model. + +Adapted from `modeling_esm.py` in huggingface/transformers. +""" + +from typing import Literal, Optional, Unpack + +# TODO: put import guard around transformer_engine here, with an informative error message around +# installation and the nvidia docker container. +import torch +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import EsmPooler +from transformers.utils import logging +from transformers.utils.generic import TransformersKwargs +from contextlib import nullcontext + +logger = logging.get_logger(__name__) + +# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below. +# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints. 
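+# Note (illustrative, not tied to a specific exported checkpoint): when this mapping is written into an exported
+# checkpoint's config.json under "auto_map", loading with AutoModelForMaskedLM.from_pretrained(..., trust_remote_code=True)
+# resolves to the NVEsmForMaskedLM class defined in esm_nv.py.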
+AUTO_MAP = {
+    "AutoConfig": "esm_nv.NVEsmConfig",
+    "AutoModel": "esm_nv.NVEsmModel",
+    "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM",
+    "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification",
+}
+
+
+class NVEsmConfig(EsmConfig):
+    """NVEsmConfig is a configuration for the NVEsm model."""
+
+    model_type: str = "nv_esm"
+
+    def __init__(
+        self,
+        qkv_weight_interleaved: bool = True,
+        encoder_activation: str = "gelu",
+        attn_input_format: Literal["bshd", "thd"] = "bshd",
+        fuse_qkv_params: bool = True,
+        micro_batch_size: Optional[int] = None,
+        max_seq_length: Optional[int] = None,
+        padded_vocab_size: Optional[int] = 64,
+        attn_mask_type: str = "padding",
+        **kwargs,
+    ):
+        """Initialize the NVEsmConfig with additional TE-related config options.
+
+        Args:
+            qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the
+                QKV weight is interpreted as a concatenation of query, key, and value weights along
+                the `0th` dimension. The default interpretation is that the individual `q`, `k`, and
+                `v` weights for each attention head are interleaved. This parameter is set to `False`
+                when using :attr:`fuse_qkv_params=False`.
+            encoder_activation: The activation function to use in the encoder.
+            attn_input_format: The input format to use for the attention. This controls whether the
+                intermediate hidden states are laid out as padded batches ('bshd') or as packed
+                sequences with no padding ('thd'). `s` stands for the sequence length, `b` batch size,
+                `h` the number of heads, `d` head size, and `t` the total number of tokens in the
+                packed batch. Note that these formats are very closely related to the `qkv_format`
+                argument of the `MultiHeadAttention` and `DotProductAttention` modules.
+            fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`,
+                `TransformerLayer` module exposes a single fused parameter for query-key-value.
+                This enables optimizations such as QKV fusion without concatenations/splits and
+                also enables the argument `fuse_wgrad_accumulation`.
+            micro_batch_size: The micro batch size to use for the attention. This is needed for
+                JIT Warmup, a technique where jit fused functions are warmed up before training to
+                ensure the same kernels are used for forward propagation and activation recompute phase.
+            max_seq_length: The maximum sequence length to use for the attention. This is needed for
+                JIT Warmup, a technique where jit fused functions are warmed up before training to
+                ensure the same kernels are used for forward propagation and activation recompute phase.
+            padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults
+                to vocab_size. Must be greater than or equal to vocab_size.
+            attn_mask_type: The type of attention mask to use.
+            **kwargs: Additional config options to pass to EsmConfig.
+        """
+        super().__init__(**kwargs)
+        # Additional TE-related config options.
+ self.qkv_weight_interleaved = qkv_weight_interleaved + self.encoder_activation = encoder_activation + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.micro_batch_size = micro_batch_size + self.max_seq_length = max_seq_length + self.attn_mask_type = attn_mask_type + + # Set padded_vocab_size with default fallback to vocab_size + self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size + + # Ensure padded_vocab_size is at least as large as vocab_size + if self.padded_vocab_size is not None and self.vocab_size is not None: + assert self.padded_vocab_size >= self.vocab_size, ( + f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})" + ) + + +class NVEsmEncoder(nn.Module): + """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmEncoder. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.layer_norm_eps, + hidden_dropout=config.hidden_dropout_prob, + attention_dropout=config.attention_probs_dropout_prob, + qkv_weight_interleaved=config.qkv_weight_interleaved, + layer_number=i + 1, + layer_type="encoder", + self_attn_mask_type=config.attn_mask_type, + activation=config.encoder_activation, + attn_input_format=config.attn_input_format, + seq_length=config.max_seq_length, + micro_batch_size=config.micro_batch_size, + num_gqa_groups=config.num_attention_heads, + fuse_qkv_params=config.fuse_qkv_params, + params_dtype=config.dtype, + window_size=(-1, -1), + ) + for i in range(config.num_hidden_layers) + ] + ) + self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps, params_dtype=config.dtype + ) + if config.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEncoder. + + Args: + hidden_states (torch.Tensor): The hidden states. + attention_mask (torch.Tensor): The attention mask. + **kwargs: Additional arguments, see TransformersKwargs for more details. + """ + all_hidden_states: tuple[torch.Tensor, ...] = () + has_thd_input = [ + x is not None + for x in [ + kwargs.get("cu_seq_lens_q", None), + kwargs.get("cu_seq_lens_k", None), + kwargs.get("max_length_q", None), + kwargs.get("max_length_k", None), + ] + ] + + if self.config.attn_input_format == "thd": + if not all(has_thd_input): + raise ValueError( + "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs." + ) + assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, ( + "THD expects embeddings shaped [1, total_tokens, hidden_size]." + ) + hidden_states = hidden_states.squeeze(0) + attention_mask = None + + elif self.config.attn_input_format == "bshd" and any(has_thd_input): + raise ValueError( + "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs." 
+            )
+
+        # Ensure that rotary embeddings are computed at a higher precision outside the torch autocast context.
+        with torch.autocast(device_type="cuda", enabled=False):
+            if self.config.position_embedding_type == "rotary":
+                if self.config.attn_input_format == "bshd":
+                    te_rope_emb = self.rotary_embeddings(max_seq_len=hidden_states.shape[1])
+                elif self.config.attn_input_format == "thd":
+                    te_rope_emb = self.rotary_embeddings(
+                        max_seq_len=kwargs["cu_seq_lens_q_padded"][-1]
+                        if "cu_seq_lens_q_padded" in kwargs
+                        else kwargs["cu_seq_lens_q"][-1]
+                    )
+                te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
+
+        for layer_module in self.layers:
+            if kwargs.get("output_hidden_states", False):
+                all_hidden_states = (*all_hidden_states, hidden_states)
+
+            if layer_module in {self.layers[0], self.layers[-1]}:
+                fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False)
+            else:
+                fp8_context = nullcontext()
+
+            with fp8_context:
+                hidden_states = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb=te_rope_emb,
+                    cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
+                    cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
+                    cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
+                    cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
+                    max_seqlen_q=kwargs.get("max_length_q", None),
+                    max_seqlen_kv=kwargs.get("max_length_k", None),
+                    pad_between_seqs=kwargs.get("pad_between_seqs", None),
+                )
+
+        hidden_states = self.emb_layer_norm_after(hidden_states)
+
+        if kwargs.get("output_hidden_states", False):
+            all_hidden_states = (*all_hidden_states, hidden_states)
+
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states if all_hidden_states else None,
+        )
+
+
+class NVEsmPreTrainedModel(PreTrainedModel):
+    """An abstract class to handle weights initialization and pretrained model loading."""
+
+    config_class = NVEsmConfig
+    base_model_prefix = "esm"
+    supports_gradient_checkpointing = False
+    accepts_loss_kwargs = False
+    _no_split_modules = (
+        "TransformerLayer",
+        "EsmEmbeddings",
+    )
+
+    def _init_weights(self, module: nn.Module):
+        """Initialize model weights.
+
+        This method ensures that models with randomly-initialized weights get the correct initial value distribution,
+        which can be critical for training stability. We also call this method directly when using meta-device init, as
+        the `to_empty` method does not initialize the weights. While the base Transformers model has a similar method,
+        we need to extend it to handle TE-specific modules.
+
+        Args:
+            module (nn.Module): The module to initialize the weights for.
+        """
+        if isinstance(
+            module, (nn.Linear, transformer_engine.pytorch.Linear, transformer_engine.pytorch.LayerNormLinear)
+        ):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        if isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        if isinstance(module, (nn.LayerNorm, transformer_engine.pytorch.LayerNorm)):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, transformer_engine.pytorch.LayerNormLinear):
+            if module.layer_norm_bias is not None:
+                module.layer_norm_bias.data.zero_()
+            module.layer_norm_weight.data.fill_(1.0)
+        if isinstance(module, transformer_engine.pytorch.LayerNormMLP):
+            if module.layer_norm_bias is not None:
+                module.layer_norm_bias.data.zero_()
+            module.layer_norm_weight.data.fill_(1.0)
+            if hasattr(module, "fc1_weight") and module.fc1_weight is not None:
+                module.fc1_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if hasattr(module, "fc2_weight") and module.fc2_weight is not None:
+                module.fc2_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if hasattr(module, "fc1_bias") and module.fc1_bias is not None and module.fc1_bias.numel() > 0:
+                module.fc1_bias.data.zero_()
+            if hasattr(module, "fc2_bias") and module.fc2_bias is not None and module.fc2_bias.numel() > 0:
+                module.fc2_bias.data.zero_()
+        if isinstance(module, RotaryPositionEmbedding) and hasattr(module, "inv_freq"):
+            # When we initialize the model with `to_empty`, the `inv_freq` attribute is not initialized, so we need to
+            # re-initialize it here with the correct values.
+            module.inv_freq = RotaryPositionEmbedding(
+                self.config.hidden_size // self.config.num_attention_heads
+            ).inv_freq.to(module.inv_freq.device)
+
+    @classmethod
+    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
+        """Override the default get_init_context method to allow for fp8 model initialization."""
+        return []
+
+
+class NVEsmModel(NVEsmPreTrainedModel):
+    """The ESM Encoder-only protein language model.
+
+    This model uses NVIDIA's TransformerEngine to optimize attention layer training and inference.
+    """
+
+    def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True):
+        """Initialize a NVEsmModel.
+
+        Args:
+            config (NVEsmConfig): The configuration of the model.
+            add_pooling_layer (bool): Whether to add a pooling layer.
+        """
+        super().__init__(config)
+        self.config = config
+
+        # Ensure pad_token_id is set properly, defaulting to 0 if not specified
+        if not hasattr(config, "pad_token_id") or config.pad_token_id is None:
+            config.pad_token_id = 0
+        self.embeddings = NVEsmEmbeddings(config)
+        self.encoder = NVEsmEncoder(config)
+        self.pooler = EsmPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        """Get the input embeddings of the model."""
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value: torch.Tensor):
+        """Set the input embeddings of the model.
+
+        Args:
+            value (torch.Tensor): The input embeddings.
+ """ + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """Forward pass of the NVEsmModel. + + Args: + input_ids (torch.Tensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.Tensor): The position ids. + inputs_embeds (torch.Tensor): The input embeddings. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + BaseModelOutputWithPooling: The output of the model. + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # TE expects a boolean attention mask, where 1s are masked and 0s are not masked + extended_attention_mask = extended_attention_mask < -1 + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class NVEsmForMaskedLM(NVEsmPreTrainedModel): + """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" + + _tied_weights_keys = ("lm_head.decoder.weight",) + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmForMaskedLM. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.lm_head = NVEsmLMHead(config) + + self.init_weights() + self.post_init() + + def get_output_embeddings(self): + """Get the output embeddings of the model.""" + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings of the model.""" + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MaskedLMOutput: + """Forward pass of the NVEsmForMaskedLM. + + Args: + input_ids (torch.LongTensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.LongTensor): The position ids. + inputs_embeds (torch.FloatTensor): The input embeddings. + labels (torch.LongTensor): The labels. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + MaskedLMOutput: The output of the model. + """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + # Truncate logits back to original vocab_size if padding was used + if self.config.padded_vocab_size != self.config.vocab_size: + prediction_scores = prediction_scores[..., : self.config.vocab_size] + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + ) + + +class NVEsmLMHead(nn.Module): + """ESM Head for masked language modeling using TransformerEngine.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmLMHead. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + ) + + self.decoder = transformer_engine.pytorch.LayerNormLinear( + config.hidden_size, + config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, + bias=True, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + ) + + def forward(self, features, **kwargs): + """Forward pass of the NVEsmLMHead. + + Args: + features (torch.Tensor): The features. + **kwargs: Additional arguments. 
+        """
+        x = self.dense(features)
+        x = torch.nn.functional.gelu(x)
+        x = self.decoder(x)
+        return x
+
+
+class NVEsmEmbeddings(nn.Module):
+    """Modified version of EsmEmbeddings to support THD inputs."""
+
+    def __init__(self, config):
+        """Initialize a NVEsmEmbeddings."""
+        super().__init__()
+        self.word_embeddings = nn.Embedding(
+            config.padded_vocab_size,
+            config.hidden_size,
+            padding_idx=config.pad_token_id,
+            dtype=config.dtype,
+        )
+
+        self.layer_norm = (
+            transformer_engine.pytorch.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+            if config.emb_layer_norm_before
+            else None
+        )
+
+        if config.position_embedding_type != "rotary":
+            raise ValueError(
+                "The TE-accelerated ESM-2 model only supports rotary position embeddings, received "
+                f"{config.position_embedding_type}"
+            )
+
+        self.padding_idx = config.pad_token_id
+        self.token_dropout = config.token_dropout
+        self.mask_token_id = config.mask_token_id
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ):
+        """Forward pass of the NVEsmEmbeddings."""
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
+        # embedding_scale factor here.
+        embeddings = inputs_embeds
+
+        if (
+            kwargs.get("cu_seq_lens_q") is not None
+            and kwargs.get("cu_seq_lens_k") is not None
+            and kwargs.get("max_length_q") is not None
+            and kwargs.get("max_length_k") is not None
+        ):
+            using_thd = True
+            attention_mask = None
+        else:
+            using_thd = False
+
+        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
+        # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
+        # masked tokens are treated as if they were selected for input dropout and zeroed out.
+        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
+        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
+        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
+        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
+        if self.token_dropout and input_ids is not None:
+            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+            mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
+
+            if not using_thd:
+                # BSHD token dropout correction
+                src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
+                n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float()
+                mask_ratio_observed = n_masked_per_seq / src_lengths
+                scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
+                embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype)
+
+            else:
+                src_lengths = torch.diff(kwargs["cu_seq_lens_q"])
+                # We need to find the number of masked tokens in each sequence in the packed batch.
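+                # Worked example with illustrative numbers: for cu_seq_lens_q = [0, 100, 250],
+                # nested_tensor_from_jagged regroups the flat boolean mask into per-sequence rows, so .sum(1) yields a
+                # masked-token count per packed sequence. If the first sequence (length 100) has 12 masked tokens,
+                # mask_ratio_observed = 0.12 and scale_factor = (1 - 0.12) / (1 - 0.12) = 1.0, i.e. no rescaling when
+                # the observed mask rate matches the 0.15 * 0.8 training rate.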
+ is_masked = (input_ids == self.mask_token_id).squeeze(0) + n_masked_per_seq = torch.nested.nested_tensor_from_jagged( + is_masked, offsets=kwargs["cu_seq_lens_q"] + ).sum(1) + mask_ratio_observed = n_masked_per_seq.float() / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + + return embeddings + + +class NVEsmForTokenClassification(NVEsmPreTrainedModel): + """Adds a token classification head to the model. + + Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`. + """ + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = transformer_engine.pytorch.Linear( + config.hidden_size, config.num_labels, params_dtype=config.dtype + ) + + self.init_weights() + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 293eafe84d..eaeb4f4510 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -27,6 +27,8 @@ from transformer_engine.common.recipe import Format from transformers import AutoConfig, AutoModelForMaskedLM +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM + # This import seems to be needed with meta device init and AutoModel.from_config from transformers.models.esm.modeling_esm import EsmForMaskedLM # noqa: F401 @@ -68,7 +70,9 @@ def main(args: DictConfig) -> float | None: ) # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, dtype=torch.bfloat16) + # change model_tag to local tag. 
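+    # Note: the recipe now builds the config and model directly from the local modeling_esm_te.py (NVEsmConfig /
+    # NVEsmForMaskedLM) instead of AutoConfig / AutoModelForMaskedLM with trust_remote_code, so edits to the local
+    # modeling code (such as the FP8 layer changes in this patch series) take effect directly.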
+ config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) + # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -80,7 +84,7 @@ def main(args: DictConfig) -> float | None: torch.device("meta") if args.use_meta_device else nullcontext(), transformer_engine.pytorch.fp8_model_init(recipe=fp8_recipe, **args.fp8_config.fp8_model_init_kwargs), ): - model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + model = NVEsmForMaskedLM(config) logger.info("Initialized Model:\n%s", model) diff --git a/ci/scripts/check_copied_files.py b/ci/scripts/check_copied_files.py index 5ea6006186..50b9ddb2ba 100755 --- a/ci/scripts/check_copied_files.py +++ b/ci/scripts/check_copied_files.py @@ -34,6 +34,7 @@ "bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py", "bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py", "bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py", + "bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py", ], "bionemo-recipes/models/esm2/src/esm/collator.py": [ "bionemo-recipes/recipes/esm2_native_te/collator.py", From 4574b54fb6bb6f9873e824cae62eff85845a34e0 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Wed, 7 Jan 2026 12:34:55 -0800 Subject: [PATCH 2/6] adds bf16 for first 3 layers and last 6 layers Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index 144921bb06..a41a9a80c5 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -224,7 +224,7 @@ def forward( if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - if layer_module in {self.layers[0], self.layers[-1]}: + if layer_module in {self.layers[0], self.layers[1], self.layers[2], self.layers[-6], self.layers[-5], self.layers[-4], self.layers[-3], self.layers[-2], self.layers[-1]}: fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False) else: fp8_context = nullcontext() From 1e1380dbb1e88d24a46ab6a90dc934559b02d8be Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Wed, 7 Jan 2026 14:58:36 -0800 Subject: [PATCH 3/6] adds LM head to bf16 Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/modeling_esm_te.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index a41a9a80c5..42df3809df 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -488,6 +488,8 @@ def forward( **kwargs, ) sequence_output = outputs[0] + # Do the FP8 autocast here. + prediction_scores = self.lm_head(sequence_output) # Truncate logits back to original vocab_size if padding was used @@ -540,9 +542,10 @@ def forward(self, features, **kwargs): features (torch.Tensor): The features. **kwargs: Additional arguments. 
""" - x = self.dense(features) - x = torch.nn.functional.gelu(x) - x = self.decoder(x) + with transformer_engine.pytorch.fp8_autocast(enabled=False): + x = self.dense(features) + x = torch.nn.functional.gelu(x) + x = self.decoder(x) return x From 0a678f69231da1944ecd6615334c7328ce90e7f5 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Wed, 7 Jan 2026 15:23:00 -0800 Subject: [PATCH 4/6] adds layernormemb to autocast false so its bf16 Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index 42df3809df..9dd8d9ec2a 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -243,7 +243,8 @@ def forward( pad_between_seqs=kwargs.get("pad_between_seqs", None), ) - hidden_states = self.emb_layer_norm_after(hidden_states) + with transformer_engine.pytorch.fp8_autocast(enabled=False): + hidden_states = self.emb_layer_norm_after(hidden_states) if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) From b8aefdc3a2680fb8fc81b290bbdad1ce0cc8a62d Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Wed, 7 Jan 2026 17:48:28 -0800 Subject: [PATCH 5/6] bf16 first last and head layers commit before this one worked Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index 9dd8d9ec2a..6c860f4f39 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -224,7 +224,7 @@ def forward( if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - if layer_module in {self.layers[0], self.layers[1], self.layers[2], self.layers[-6], self.layers[-5], self.layers[-4], self.layers[-3], self.layers[-2], self.layers[-1]}: + if layer_module in {self.layers[0], self.layers[-1]}: fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False) else: fp8_context = nullcontext() @@ -243,8 +243,7 @@ def forward( pad_between_seqs=kwargs.get("pad_between_seqs", None), ) - with transformer_engine.pytorch.fp8_autocast(enabled=False): - hidden_states = self.emb_layer_norm_after(hidden_states) + hidden_states = self.emb_layer_norm_after(hidden_states) if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) From aaad421ee3b10dce02f348f76570be5f0bfc0e3f Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Thu, 8 Jan 2026 11:20:17 -0800 Subject: [PATCH 6/6] turns off BF16 for intermediate layers of the transformer and only uses BF16 for the LM Head Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index 6c860f4f39..934d0166bc 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -224,10 +224,10 @@ def forward( if kwargs.get("output_hidden_states", False): 
all_hidden_states = (*all_hidden_states, hidden_states) - if layer_module in {self.layers[0], self.layers[-1]}: - fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False) - else: - fp8_context = nullcontext() + # if layer_module in {self.layers[0], self.layers[-1]}: + # fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False) + # else: + fp8_context = nullcontext() with fp8_context: hidden_states = layer_module(