@@ -510,6 +510,39 @@ def llama_gemma_rotary_emb_forward(self, x, position_ids, seq_len=None):
     return cos, sin


+def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000, inv_freq=None) -> torch.Tensor:
+    # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
+    if inv_freq is None:
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
+
+    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
+    emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
+    return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)
+
+
+def register_sin_cos_buffer(model):
+    max_positions = model.config.max_position_embeddings
+
+    # cos/sin for rotary position embeddings also have accuracy issues with bf16 and are inefficient
+    # to recompute on each step, so use precomputed values
+
+    rotary_emb = model.model.layers[0].self_attn.rotary_emb
+    dim, base = None, None
+    inv_freq = getattr(rotary_emb, "inv_freq", None)
+    if inv_freq is None:
+        base = rotary_emb.base
+        dim = rotary_emb.dim
+    embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)
+
+    for layer in model.model.layers:
+        layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
+        layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
+
+        layer.self_attn.rotary_emb.forward = types.MethodType(
+            llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
+        )
+
+
 class LlamaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
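
Aside, for illustration only (not part of the diff): `create_sinusoidal_positions` precomputes one table with the sin and cos values concatenated along dim 1, so a patched rotary forward only has to index rows by position instead of recomputing the angles on every step. The snippet below rebuilds the same table with small, hypothetical sizes; the final `torch.split` is an assumption about how a consumer such as `llama_gemma_rotary_emb_forward` (defined earlier in the file, outside this hunk) could slice it back into sin/cos halves.

```python
import torch

# Rebuild the table with toy sizes (8 positions, head dim 4). Illustration only.
num_pos, dim, base = 8, 4, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))          # (dim // 2,)
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos).float(), inv_freq)   # (num_pos, dim // 2)
emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)                                  # (num_pos, dim)
table = torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)                             # (num_pos, 2 * dim)

print(table.shape)  # torch.Size([8, 8]): sin in the first half of the columns, cos in the second
# One plausible way a patched forward could recover the two halves:
sin, cos = torch.split(table, table.shape[-1] // 2, dim=-1)
```
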
@@ -521,39 +554,148 @@ def __enter__(self):
             self._model.model._update_causal_mask = types.MethodType(
                 _llama_gemma_update_causal_mask, self._model.model
             )
+        register_sin_cos_buffer(self._model)

-        max_positions = self._model.config.max_position_embeddings
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask

-        # cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step
-        # use precomputed
-        def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000) -> torch.Tensor:
-            # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
-            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
+        for layer in self._model.model.layers:
+            layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward

-            sinusoid_inp = torch.einsum(
-                "i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq
-            ).float()
-            emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
-            return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)

-        base = self._model.model.layers[0].self_attn.rotary_emb.base
-        dim = self._model.model.layers[0].self_attn.rotary_emb.dim
-        embed_positions = create_sinusoidal_positions(max_positions, dim, base)
+# copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42
+def _mistral_update_causal_mask(
+    self,
+    attention_mask: torch.Tensor,
+    input_tensor: torch.Tensor,
+    cache_position: torch.Tensor,
+    past_key_values: "Cache",
+    use_cache: bool,
+    output_attentions: bool,
+):
+    from transformers.cache_utils import SlidingWindowCache, StaticCache
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter

-        for layer in self._model.model.layers:
-            layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
-            layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if self._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and use_cache:
+            is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+            if is_padding_right:
+                raise ValueError(
+                    "You are attempting to perform batched generation with padding_side='right'"
+                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                )
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+
+    # cache_position must be valid here no matter which cache we use
+    past_seen_tokens = cache_position[0] if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+    using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and not (using_static_cache or using_sliding_window_cache)
+        and not output_attentions
+    ):
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            sliding_window=self.config.sliding_window,
+            is_training=self.training,
+        ):
+            return None

-            layer.self_attn.rotary_emb.forward = types.MethodType(
-                llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
+    dtype, device = input_tensor.dtype, input_tensor.device
+    min_dtype = torch.finfo(dtype).min
+    sequence_length = input_tensor.shape[1]
+    # SlidingWindowCache
+    if using_sliding_window_cache:
+        target_length = max(sequence_length, self.config.sliding_window)
+    # StaticCache
+    elif using_static_cache:
+        target_length = past_key_values.get_max_length()
+    # DynamicCache or no cache
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+        if attention_mask.max() != 0:
+            raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+        causal_mask = attention_mask
+    else:
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        if self.config.sliding_window is not None:
+            if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
+                exclude_mask = exclude_mask.bitwise_or(
+                    torch.arange(target_length, device=device)
+                    <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
+                )
+        causal_mask *= exclude_mask
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            if attention_mask.dim() == 2:
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
                 )

+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+class MistralModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.42.0"):
+            # apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548
+            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
+            self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)
+
+        # mistral has some accuracy issues in bf16 with transformers >= 4.42;
+        # precompute rotary emb sin/cos to avoid this issue
+        register_sin_cos_buffer(self._model)
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
+
         if hasattr(self._model.model, "_orig_update_causal_mask"):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask

-        for layer in self._model.model.layers:
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
                 layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward


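Aside, for illustration only (not part of the diff): `MistralModelPatcher.__enter__`/`__exit__` follow the same save/patch/restore idiom used throughout this file: stash the original bound method on the instance (`_orig_forward`, `_orig_update_causal_mask`), rebind a replacement with `types.MethodType`, and restore it on exit. A minimal, self-contained sketch of that idiom, using hypothetical `Toy`/`PatchForward` names:

```python
import types


class Toy:
    def forward(self):
        return "original"


def patched_forward(self):
    return "patched"


class PatchForward:
    """Minimal sketch of the save/patch/restore pattern used by the model patchers."""

    def __init__(self, obj):
        self.obj = obj

    def __enter__(self):
        # stash the original bound method, then rebind the replacement onto the instance
        self.obj._orig_forward = self.obj.forward
        self.obj.forward = types.MethodType(patched_forward, self.obj)
        return self.obj

    def __exit__(self, exc_type, exc_value, traceback):
        # restore only if we actually patched, mirroring the hasattr checks in the diff
        if hasattr(self.obj, "_orig_forward"):
            self.obj.forward = self.obj._orig_forward


toy = Toy()
with PatchForward(toy):
    assert toy.forward() == "patched"
assert toy.forward() == "original"
```
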
@@ -1283,11 +1425,15 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )

+        # phi3 has an issue with bf16 inference; precompute sin/cos for rotary_position_embedding to avoid accuracy issues
+        register_sin_cos_buffer(self._model)
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
+            layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward


 def _aquila_self_attn_sdpa_forward(
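
Aside, for illustration only (not part of the diff): the motivation for `register_sin_cos_buffer` here is that computing rotary angles on the fly in bf16 loses precision at large positions, while a float32 table computed once does not. A rough, self-contained comparison with arbitrary toy sizes (not tied to phi3's actual config):

```python
import torch

# Toy sizes; real models take dim/base/max positions from their config.
dim, num_pos, base = 128, 4096, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
positions = torch.arange(num_pos, dtype=torch.float32)

angles_fp32 = torch.outer(positions, inv_freq)                        # precomputed-table path
angles_bf16 = torch.outer(positions.bfloat16(), inv_freq.bfloat16())  # on-the-fly bf16 path

max_err = (angles_fp32.cos() - angles_bf16.float().cos()).abs().max()
print(f"max |cos| deviation between fp32 and bf16 angle computation: {max_err:.3f}")
```
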
@@ -1807,6 +1953,18 @@ def __enter__(self):
             _dbrx_update_causal_mask, self._model.transformer
         )

+        # starting from transformers 4.41, the issue is also observable in the sin/cos calculation for rotary_emb
+        patch_rope_sin_cos = is_transformers_version(">=", "4.41.0")
+
+        inv_freq = getattr(self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb, "inv_freq")
+        dim, base = None, None
+        if inv_freq is None:
+            dim = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.dim
+            base = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.base
+        max_positions = self._model.config.max_seq_len
+        if patch_rope_sin_cos:
+            embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)
+
         for block in self._model.transformer.blocks:
             rotary_emb = block.norm_attn_norm.attn.rotary_emb
             # initialize inv_freq for torchscript tracing
@@ -1815,6 +1973,12 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
                 rotary_emb.inv_freq = inv_freq
+
+            if patch_rope_sin_cos:
+                rotary_emb.register_buffer("embed_positions", embed_positions)
+                rotary_emb._orig_forward = rotary_emb.forward
+                rotary_emb.forward = types.MethodType(llama_gemma_rotary_emb_forward, rotary_emb)
+
             # remove continue-operator from iteration loop over experts
             block.ffn.experts._orig_forward = block.ffn.experts.forward
             block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)
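
Aside, for illustration only (not part of the diff): when a rotary module no longer stores `inv_freq` (as can happen with newer transformers releases), the patch rebuilds it from `base` and `dim` with the standard RoPE formula before calling `create_sinusoidal_positions`. A tiny numeric sketch with hypothetical values:

```python
import torch

# Hypothetical rotary dimension and base, chosen small enough to eyeball.
base, dim = 10000, 8
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq)  # [1.0, 0.1, 0.01, 0.001]: one inverse frequency per pair of rotary channels
```
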
@@ -1825,6 +1989,9 @@ def __exit__(self, exc_type, exc_value, traceback):
         for block in self._model.transformer.blocks:
             block.ffn.experts.forward = block.ffn.experts._orig_forward

+            if hasattr(block.norm_attn_norm.attn.rotary_emb, "_orig_forward"):
+                block.norm_attn_norm.attn.rotary_emb.forward = block.norm_attn_norm.attn.rotary_emb._orig_forward
+

 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/persimmon/modeling_persimmon.py#L264
 def _persimmon_self_attn_sdpa_forward(