 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import math
 from contextlib import nullcontext
 from typing import Literal, Optional, Union
@@ -659,20 +660,21 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
 
         if layer_id % config.global_attn_every_n_layers != 0:
             self.local_attention = (config.local_attention // 2, config.local_attention // 2)
+            rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
+            max_position_embeddings = config.local_attention
         else:
             self.local_attention = (-1, -1)
-
-        max_position_embeddings = config.max_position_embeddings
-        if self.local_attention != (-1, -1):
-            rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta
-            max_position_embeddings = config.local_attention
+            max_position_embeddings = config.max_position_embeddings
+            rope_theta = config.global_rope_theta
 
         if config._attn_implementation == "flash_attention_2":
             self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
                 dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
             )
         else:
-            self.rotary_emb = ModernBertRotaryEmbedding(config=config)
+            config_copy = copy.deepcopy(config)
+            config_copy.rope_theta = rope_theta
+            self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
 
         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
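The hunk above resolves the RoPE parameters per layer: sliding-window ("local") attention layers take local_rope_theta (falling back to the global value) and use the window size as their effective maximum sequence length, while global layers keep global_rope_theta and the full max_position_embeddings; the copy.deepcopy keeps the per-layer rope_theta off the shared config. Below is a minimal standalone sketch of that selection logic; the Cfg dataclass, its default values, and the rope_params_for_layer helper are hypothetical stand-ins, not the transformers API.

from dataclasses import dataclass
from typing import Optional
import copy

@dataclass
class Cfg:
    # Hypothetical stand-in for ModernBertConfig; field names mirror the diff.
    global_attn_every_n_layers: int = 3
    local_attention: int = 128              # sliding-window span used by local layers
    max_position_embeddings: int = 8192
    global_rope_theta: float = 160000.0
    local_rope_theta: Optional[float] = 10000.0
    rope_theta: Optional[float] = None      # resolved per layer below

def rope_params_for_layer(cfg: Cfg, layer_id: int) -> Cfg:
    """Return a per-layer deep copy with rope_theta and the effective max
    sequence length resolved, mirroring the if/else branches in the hunk."""
    layer_cfg = copy.deepcopy(cfg)          # leave the shared config untouched
    if layer_id % cfg.global_attn_every_n_layers != 0:
        # Local layer: fall back to the global theta if no local one is set.
        layer_cfg.rope_theta = cfg.local_rope_theta if cfg.local_rope_theta is not None else cfg.global_rope_theta
        layer_cfg.max_position_embeddings = cfg.local_attention
    else:
        # Global attention layer: keep the full context length and the global theta.
        layer_cfg.rope_theta = cfg.global_rope_theta
    return layer_cfg

print(rope_params_for_layer(Cfg(), 1).rope_theta)  # 10000.0 (local layer)
print(rope_params_for_layer(Cfg(), 3).rope_theta)  # 160000.0 (global layer)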
@@ -811,7 +813,9 @@ def init_weight(module: nn.Module, std: float):
             if module.bias is not None:
                 module.bias.data.zero_()
 
-    def set_attention_implementation(self, attn_implementation: Union[dict, str]):
+    def _check_and_adjust_attn_implementation(
+        self, attn_implementation: Optional[str], is_init_check: bool = False
+    ) -> str:
         """
         Checks and dispatches to the requested attention implementation.
         """
@@ -820,16 +824,17 @@ def set_attention_implementation(self, attn_implementation: Union[dict, str]):
         # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
         # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
 
-        requested_attn_implementation = self._check_attn_implementation(attn_implementation)
         try:
             attn_implementation = (
                 "flash_attention_2"
-                if requested_attn_implementation is None and self._flash_attn_2_can_dispatch()
+                if attn_implementation is None and self._flash_attn_2_can_dispatch()
                 else attn_implementation
             )
         except (ValueError, ImportError):
             pass
-        return super().set_attention_implementation(attn_implementation=attn_implementation)
+        return super()._check_and_adjust_attn_implementation(
+            attn_implementation=attn_implementation, is_init_check=is_init_check
+        )
 
     def _maybe_set_compile(self):
         if self.config.reference_compile is False:
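The last two hunks rename the override to match the new base-class hook (_check_and_adjust_attn_implementation, which now receives the requested implementation directly plus an is_init_check flag) and drop the intermediate _check_attn_implementation call. The resulting pattern is: when nothing was requested explicitly, try to upgrade to flash_attention_2 if it can dispatch, swallow dispatch errors, then defer to the parent. A minimal standalone sketch of that pattern follows; BaseModel, can_use_flash_attention, and the returned defaults are hypothetical stand-ins for the real transformers methods.

from typing import Optional

class BaseModel:
    def _check_and_adjust_attn_implementation(
        self, attn_implementation: Optional[str], is_init_check: bool = False
    ) -> str:
        # Parent fallback: default to the eager implementation when nothing was chosen.
        return attn_implementation or "eager"

class ModernBertLike(BaseModel):
    def can_use_flash_attention(self) -> bool:
        # Stand-in for self._flash_attn_2_can_dispatch(); may raise if FA2 is unavailable.
        return True

    def _check_and_adjust_attn_implementation(
        self, attn_implementation: Optional[str], is_init_check: bool = False
    ) -> str:
        # Prefer FA2 only when the caller did not request anything explicitly.
        try:
            if attn_implementation is None and self.can_use_flash_attention():
                attn_implementation = "flash_attention_2"
        except (ValueError, ImportError):
            pass  # fall through and let the parent pick a default
        return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)

print(ModernBertLike()._check_and_adjust_attn_implementation(None))    # flash_attention_2
print(ModernBertLike()._check_and_adjust_attn_implementation("sdpa"))  # sdpa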