@@ -3237,3 +3237,144 @@ def __init__(
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+
+
+def minicpm3_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value=None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+        """Applies Rotary Position Embedding to the query and key tensors.
+        Args:
+            q (`torch.Tensor`): The query tensor.
+            k (`torch.Tensor`): The key tensor.
+            cos (`torch.Tensor`): The cosine part of the rotary embedding.
+            sin (`torch.Tensor`): The sine part of the rotary embedding.
+            position_ids (`torch.Tensor`):
+                The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+                used to pass offsetted position ids when working with a KV-cache.
+            unsqueeze_dim (`int`, *optional*, defaults to 1):
+                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+        Returns:
+            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+        """
+        orig_dtype = k.dtype
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+        q_fp32 = q.to(dtype=torch.float32, device=q.device)
+        k_fp32 = k.to(dtype=torch.float32, device=k.device)
+        q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
+        k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
+        return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    bsz, q_len, _ = hidden_states.shape
+
+    q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+    q = q.view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.q_head_dim).transpose(1, 2)
+    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+    k_pe = k_pe.view(hidden_states.shape[0], hidden_states.shape[1], 1, self.qk_rope_head_dim).transpose(1, 2)
+    kv = (
+        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        .view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        .transpose(1, 2)
+    )
+
+    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+    kv_seq_len = value_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+    # Difference from the original code: k_pe.new_empty creates a constant-shaped tensor under torchscript
+    # tracing, so the query/key tensors are assembled with torch.concat instead of slice assignment.
+    query_states = torch.concat([q_nope, q_pe], dim=-1)
+    # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+    # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+    key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
+    # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+    # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(hidden_states.shape[0], hidden_states.shape[1], self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
+class MiniCPM3Patcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for block in self._model.model.layers:
+            block.self_attn._orig_forward = block.self_attn.forward
+            block.self_attn.forward = types.MethodType(minicpm3_attn_forward, block.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.model.layers:
+            block.self_attn.forward = block.self_attn._orig_forward
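
For context on the patching pattern used by `MiniCPM3Patcher`, here is a minimal, self-contained sketch of the same mechanism: a context manager rebinds each block's `forward` with `types.MethodType` on entry and restores the original on exit. `ToyAttention` and `ToyPatcher` are illustrative names, not part of this commit or of optimum-intel.

```python
import types

import torch
from torch import nn


def patched_forward(self, x):
    # stand-in for minicpm3_attn_forward: same call signature, different behaviour
    return torch.zeros_like(self.proj(x))


class ToyPatcher:
    """Context manager that swaps each block's forward, mirroring MiniCPM3Patcher."""

    def __init__(self, blocks):
        self._blocks = blocks

    def __enter__(self):
        for block in self._blocks:
            block._orig_forward = block.forward
            block.forward = types.MethodType(patched_forward, block)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for block in self._blocks:
            block.forward = block._orig_forward


class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


blocks = nn.ModuleList([ToyAttention(), ToyAttention()])
x = torch.randn(1, 4)

with ToyPatcher(blocks):
    print(torch.equal(blocks[0](x), torch.zeros(1, 4)))  # True: patched forward is active
print(torch.equal(blocks[0](x), blocks[0].proj(x)))      # True: original forward restored
```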
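
The commented-out `new_empty` lines preserve the upstream implementation; as the in-code comment explains, `new_empty` bakes a constant-shaped tensor into the traced graph, which is why the patch assembles the query/key tensors with `torch.concat`. A small illustrative check (assumed, not part of this commit) that the concat form stays shape-polymorphic under `torch.jit.trace`:

```python
import torch


def build_query(q_nope, q_pe):
    # concat keeps the output shape tied to the inputs when traced
    return torch.concat([q_nope, q_pe], dim=-1)


traced = torch.jit.trace(build_query, (torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 2)))
# the traced graph still accepts a different sequence length
out = traced(torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 2))
print(out.shape)  # torch.Size([1, 2, 5, 6])
```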