
Commit 1ad216b

[modernbert] fix regression (#39750)
* fix regression
* add FA2 test
1 parent 379209b commit 1ad216b

File tree

3 files changed: +40, -20 lines changed


src/transformers/models/modernbert/modeling_modernbert.py

Lines changed: 15 additions & 10 deletions
@@ -19,6 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import math
 from contextlib import nullcontext
 from typing import Optional, Union
@@ -459,20 +460,21 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):

         if layer_id % config.global_attn_every_n_layers != 0:
             self.local_attention = (config.local_attention // 2, config.local_attention // 2)
+            rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
+            max_position_embeddings = config.local_attention
         else:
             self.local_attention = (-1, -1)
-
-        max_position_embeddings = config.max_position_embeddings
-        if self.local_attention != (-1, -1):
-            rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta
-            max_position_embeddings = config.local_attention
+            max_position_embeddings = config.max_position_embeddings
+            rope_theta = config.global_rope_theta

         if config._attn_implementation == "flash_attention_2":
             self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
                 dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
             )
         else:
-            self.rotary_emb = ModernBertRotaryEmbedding(config=config)
+            config_copy = copy.deepcopy(config)
+            config_copy.rope_theta = rope_theta
+            self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)

         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
@@ -611,7 +613,9 @@ def init_weight(module: nn.Module, std: float):
             if module.bias is not None:
                 module.bias.data.zero_()

-    def set_attention_implementation(self, attn_implementation: Union[dict, str]):
+    def _check_and_adjust_attn_implementation(
+        self, attn_implementation: Optional[str], is_init_check: bool = False
+    ) -> str:
         """
         Checks and dispatches to the requested attention implementation.
         """
@@ -620,16 +624,17 @@ def set_attention_implementation(self, attn_implementation: Union[dict, str]):
         # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
         # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.

-        requested_attn_implementation = self._check_attn_implementation(attn_implementation)
         try:
             attn_implementation = (
                 "flash_attention_2"
-                if requested_attn_implementation is None and self._flash_attn_2_can_dispatch()
+                if attn_implementation is None and self._flash_attn_2_can_dispatch()
                 else attn_implementation
             )
         except (ValueError, ImportError):
             pass
-        return super().set_attention_implementation(attn_implementation=attn_implementation)
+        return super()._check_and_adjust_attn_implementation(
+            attn_implementation=attn_implementation, is_init_check=is_init_check
+        )

     def _maybe_set_compile(self):
         if self.config.reference_compile is False:
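To make the rotary-embedding half of the fix easier to follow, here is a minimal, self-contained sketch of the per-layer RoPE parameter selection the patch restores. It is not part of the commit: the SimpleNamespace config and its values are illustrative stand-ins for ModernBertConfig, chosen only to show local layers picking local_rope_theta and global layers picking global_rope_theta.

# Sketch (not from the commit): per-layer RoPE selection as implemented above.
import copy
from types import SimpleNamespace

# Illustrative stand-in for ModernBertConfig; the numbers are made up for the demo.
cfg = SimpleNamespace(
    global_attn_every_n_layers=3,
    local_attention=128,
    max_position_embeddings=8192,
    global_rope_theta=160000.0,
    local_rope_theta=10000.0,
)

def rope_params_for_layer(config, layer_id):
    """Mirror the patched __init__: pick rope_theta and max_position_embeddings per layer."""
    if layer_id % config.global_attn_every_n_layers != 0:
        # Local-attention layer: prefer local_rope_theta, fall back to the global value.
        rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
        max_position_embeddings = config.local_attention
    else:
        # Global-attention layer.
        rope_theta = config.global_rope_theta
        max_position_embeddings = config.max_position_embeddings
    # The eager/SDPA path deep-copies the config so ModernBertRotaryEmbedding reads
    # the per-layer theta instead of the shared config value.
    config_copy = copy.deepcopy(config)
    config_copy.rope_theta = rope_theta
    return config_copy, max_position_embeddings

for layer_id in range(4):
    layer_cfg, max_pos = rope_params_for_layer(cfg, layer_id)
    print(layer_id, layer_cfg.rope_theta, max_pos)

Before the patch, the non-FA2 branch passed the original config straight to ModernBertRotaryEmbedding, so every layer ended up with the same rope_theta; the deep copy is what restores the distinct local/global behaviour.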

src/transformers/models/modernbert/modular_modernbert.py

Lines changed: 15 additions & 10 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import math
 from contextlib import nullcontext
 from typing import Literal, Optional, Union
@@ -659,20 +660,21 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):

         if layer_id % config.global_attn_every_n_layers != 0:
             self.local_attention = (config.local_attention // 2, config.local_attention // 2)
+            rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
+            max_position_embeddings = config.local_attention
         else:
             self.local_attention = (-1, -1)
-
-        max_position_embeddings = config.max_position_embeddings
-        if self.local_attention != (-1, -1):
-            rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta
-            max_position_embeddings = config.local_attention
+            max_position_embeddings = config.max_position_embeddings
+            rope_theta = config.global_rope_theta

         if config._attn_implementation == "flash_attention_2":
             self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
                 dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
             )
         else:
-            self.rotary_emb = ModernBertRotaryEmbedding(config=config)
+            config_copy = copy.deepcopy(config)
+            config_copy.rope_theta = rope_theta
+            self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)

         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
@@ -811,7 +813,9 @@ def init_weight(module: nn.Module, std: float):
             if module.bias is not None:
                 module.bias.data.zero_()

-    def set_attention_implementation(self, attn_implementation: Union[dict, str]):
+    def _check_and_adjust_attn_implementation(
+        self, attn_implementation: Optional[str], is_init_check: bool = False
+    ) -> str:
         """
         Checks and dispatches to the requested attention implementation.
         """
@@ -820,16 +824,17 @@ def set_attention_implementation(self, attn_implementation: Union[dict, str]):
         # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
         # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.

-        requested_attn_implementation = self._check_attn_implementation(attn_implementation)
         try:
             attn_implementation = (
                 "flash_attention_2"
-                if requested_attn_implementation is None and self._flash_attn_2_can_dispatch()
+                if attn_implementation is None and self._flash_attn_2_can_dispatch()
                 else attn_implementation
             )
         except (ValueError, ImportError):
             pass
-        return super().set_attention_implementation(attn_implementation=attn_implementation)
+        return super()._check_and_adjust_attn_implementation(
+            attn_implementation=attn_implementation, is_init_check=is_init_check
+        )

     def _maybe_set_compile(self):
         if self.config.reference_compile is False:
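The second half of the patch (identical in the modeling and modular files) keeps ModernBERT defaulting to flash-attention-2 by overriding _check_and_adjust_attn_implementation instead of the former set_attention_implementation hook. Below is a rough, runnable sketch of that control flow; the toy base class stands in for PreTrainedModel and is not the real transformers code.

# Sketch (not the real transformers classes): how the override upgrades an
# unspecified attention implementation to FA2 and otherwise defers to the base class.
from typing import Optional

class _BaseModelSketch:
    # Stand-in for the base-class hook named in the diff.
    def _check_and_adjust_attn_implementation(
        self, attn_implementation: Optional[str], is_init_check: bool = False
    ) -> str:
        return attn_implementation or "sdpa"

class _ModernBertSketch(_BaseModelSketch):
    flash_attn_available = True  # pretend flash-attn can dispatch on this machine

    def _flash_attn_2_can_dispatch(self) -> bool:
        return self.flash_attn_available

    def _check_and_adjust_attn_implementation(
        self, attn_implementation: Optional[str], is_init_check: bool = False
    ) -> str:
        # Only auto-select FA2 when the user did not request anything explicitly;
        # dispatch failures fall back silently to the base-class resolution.
        try:
            attn_implementation = (
                "flash_attention_2"
                if attn_implementation is None and self._flash_attn_2_can_dispatch()
                else attn_implementation
            )
        except (ValueError, ImportError):
            pass
        return super()._check_and_adjust_attn_implementation(
            attn_implementation=attn_implementation, is_init_check=is_init_check
        )

print(_ModernBertSketch()._check_and_adjust_attn_implementation(None))     # flash_attention_2
print(_ModernBertSketch()._check_and_adjust_attn_implementation("eager"))  # eager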

tests/models/modernbert/test_modeling_modernbert.py

Lines changed: 10 additions & 0 deletions
@@ -375,6 +375,16 @@ def test_saved_config_excludes_reference_compile(self):
             config_dict = json.load(f)
         self.assertNotIn("reference_compile", config_dict)

+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    def test_flash_attention_dispatches_by_defaul(self):
+        "ModernBert should dispatch to FA2 by default, not SDPA"
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config=config)
+            self.assertTrue(model.config._attn_implementation == "flash_attention_2")
+

 @require_torch
 class ModernBertModelIntegrationTest(unittest.TestCase):
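The new test builds each ModernBERT model class from a config on a flash-attn-capable GPU and asserts that config._attn_implementation resolves to "flash_attention_2". A rough manual spot-check of the same behaviour might look like the snippet below; it assumes flash-attn is installed and a CUDA device is present, and "answerdotai/ModernBERT-base" is just a convenient public checkpoint, not something the test itself loads.

# Manual spot-check (assumes flash-attn is installed and a CUDA GPU is available).
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "answerdotai/ModernBERT-base", torch_dtype=torch.bfloat16, device_map="cuda"
)
print(model.config._attn_implementation)  # expected: "flash_attention_2"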

0 commit comments
