@@ -497,50 +497,65 @@ def _llama_gemma_update_causal_mask_latest(
     _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy


-class GemmaModelPatcher(DecoderModelPatcher):
+def llama_gemma_rotary_emb_forward(self, x, position_ids, seq_len=None):
+    # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L104
+    _seq_len = torch.max(position_ids) + 1 if seq_len is None else seq_len
+    if _seq_len > self.embed_positions.shape[0]:
+        if seq_len is None:
+            return self._orig_forward(x, position_ids)
+        else:
+            return self._orig_forward(x, position_ids, seq_len)
+    sincos = self.embed_positions[position_ids]
+    sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+    return cos, sin
+
+
+class LlamaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()

-        # gemma has some accuracy issues with bf16 with transformers >= 4.39
+        # llama/gemma has some accuracy issues with bf16 with transformers >= 4.39
         # fill causal mask in slightly different way for avoid overflow on some platforms
         if is_transformers_version(">=", "4.39.0"):
             self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
             self._model.model._update_causal_mask = types.MethodType(
                 _llama_gemma_update_causal_mask, self._model.model
             )

-        # init inv_freq for torchscript tracing
-        # https://github.com/huggingface/transformers/blob/ed74d97871468f3a4695ede50abdc0b55717a84d/src/transformers/models/gemma/modeling_gemma.py#L108
-        for layer in self._model.model.layers:
-            if layer.self_attn.rotary_emb.inv_freq is None:
-                rotary_emb = layer.self_attn.rotary_emb
-                layer.self_attn.rotary_emb.inv_freq = 1.0 / (
-                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
-                )
+        max_positions = self._model.config.max_position_embeddings

-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-        if hasattr(self._model.model, "_orig_update_causal_mask"):
-            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+        # cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step
+        # use precomputed
+        def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000) -> torch.Tensor:
+            # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))

+            sinusoid_inp = torch.einsum(
+                "i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq
+            ).float()
+            emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
+            return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)

-class LlamaModelPatcher(DecoderModelPatcher):
-    def __enter__(self):
-        super().__enter__()
+        base = self._model.model.layers[0].self_attn.rotary_emb.base
+        dim = self._model.model.layers[0].self_attn.rotary_emb.dim
+        embed_positions = create_sinusoidal_positions(max_positions, dim, base)

-        # llama has some accuracy issues with bf16 with transformers >= 4.39
-        # fill causal mask in slightly different way for avoid overflow on some platforms
-        if is_transformers_version(">=", "4.39.0"):
-            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
-            self._model.model._update_causal_mask = types.MethodType(
-                _llama_gemma_update_causal_mask, self._model.model
-            )
+        for layer in self._model.model.layers:
+            layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
+            layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
+
+            layer.self_attn.rotary_emb.forward = types.MethodType(
+                llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
+            )

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         if hasattr(self._model.model, "_orig_update_causal_mask"):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask

+        for layer in self._model.model.layers:
+            layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
+

 SUPPORT_SDPA = is_torch_version(">", "2.1.0")

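For reference, a minimal standalone sketch (not part of the patch) of the precomputed sin/cos lookup that the patcher installs above: the table is built once with create_sinusoidal_positions and then indexed by position ids, instead of recomputing the rotary embeddings on every step. The toy sizes (num_pos=8, dim=4) and the final print are illustrative only.

# Standalone sketch of the precomputed rotary sin/cos table and its lookup,
# mirroring create_sinusoidal_positions / llama_gemma_rotary_emb_forward above.
import torch


def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000) -> torch.Tensor:
    # same construction as in the patch above
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
    emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
    return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)


embed_positions = create_sinusoidal_positions(num_pos=8, dim=4)  # toy sizes, illustrative only
position_ids = torch.tensor([[0, 1, 2]])                         # (batch, seq_len)
sincos = embed_positions[position_ids]                           # table lookup instead of per-step recomputation
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
print(cos.shape, sin.shape)  # torch.Size([1, 3, 4]) torch.Size([1, 3, 4])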