@@ -2467,3 +2467,134 @@ def patched_forward(*args, **kwargs):
            return outputs

        self.patched_forward = patched_forward


def _decilm_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # DeciLM's attention implementation contains a bug when past_key_value is not None
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`):
                The position indices of the tokens corresponding to the query and key tensors. For example, this can be
                used to pass offsetted position ids when working with a KV-cache.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The `unsqueeze_dim` argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """
        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
        """
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    bsz, q_len, _ = hidden_states.size()
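    # pretraining_tp > 1 reproduces the tensor-parallel layout used during pretraining:
    # the q/k/v weights are split along their output dimension and applied slice by slice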
    if self.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

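    # the rotary embeddings must cover the full sequence length, including any cached past tokens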
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

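    # expose the concatenated key/value tensors as the new cache entry when caching is enabled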
    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
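    # when an explicit attention_mask is given SDPA applies it directly; otherwise
    # is_causal=True lets scaled_dot_product_attention build the causal mask itself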
    attn_output = F.scaled_dot_product_attention(
        query_states, key_states, value_states, is_causal=attention_mask is None, attn_mask=attention_mask
    )

    # modified: the original implementation is missing the .transpose(1, 2) here
    attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)

    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
    else:
        attn_output = self.o_proj(attn_output)

    attn_weights = None

    return attn_output, attn_weights, past_key_value


class DeciLMModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()

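        # swap every attention layer's forward for the fixed implementation, keeping the original for restore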
        for layer in self._model.model.layers:
            layer.self_attn._orig_forward = layer.self_attn.forward
            layer.self_attn.forward = types.MethodType(_decilm_attn_forward, layer.self_attn)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)

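        # put the original attention forwards back once the patch is no longer needed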
        for layer in self._model.model.layers:
            layer.self_attn.forward = layer.self_attn._orig_forward
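
For reference, below is a minimal, self-contained sketch of the same patch-and-restore pattern on a toy module. ToyAttention, ToyPatcher, and _patched_forward are hypothetical names introduced only for illustration and are not part of this diff: forward is rebound with types.MethodType on __enter__ and the saved original is put back on __exit__, so the replacement is visible only inside the with block.

import types

import torch


class ToyAttention(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states


def _patched_forward(self, hidden_states):
    # stand-in for _decilm_attn_forward: any replacement bound as a method works the same way
    return hidden_states * 2


class ToyPatcher:
    def __init__(self, module):
        self._module = module

    def __enter__(self):
        # save the bound original and shadow it with the replacement on the instance
        self._module._orig_forward = self._module.forward
        self._module.forward = types.MethodType(_patched_forward, self._module)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # restore the original forward
        self._module.forward = self._module._orig_forward


attn = ToyAttention()
x = torch.ones(2)
with ToyPatcher(attn):
    assert torch.equal(attn(x), x * 2)  # patched forward is active inside the block
assert torch.equal(attn(x), x)  # original forward is restored on exit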