@@ -48,4 +48,111 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
4848
def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0):  # jitting fails with bf16
    """Eager-mode rotary position embedding for q and k.

    Slices the cached cos/sin tables to q's sequence length (starting at
    ``offset``) and applies the rotation to both q and k.  Kept un-scripted
    because TorchScript fails with bf16 inputs.
    """
    end = q.shape[0] + offset
    cos = cos[offset:end, ...]
    sin = sin[offset:end, ...]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
52+
53+
54+ # Original implementation adjusted from https://github.com/sunyt32/torchscale
55+
def fixed_pos_embedding(x, base):
    """Build a sinusoidal position table sized like *x*.

    x is a (seq_len, dim) tensor; returns (cos, sin) tables of the same
    shape, cast to x's dtype/device via ``.to(x)``.
    """
    seq_len, dim = x.shape
    # One inverse frequency per feature column.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim) / dim))
    positions = torch.arange(0, seq_len, dtype=torch.float)
    angles = torch.einsum("i , j -> i j", positions, inv_freq).to(x)
    return torch.cos(angles), torch.sin(angles)
63+
64+
class XPos(torch.nn.Module):
    """
    xPos positional embeddings from https://arxiv.org/abs/2212.10554.
    """

    def __init__(self, head_dim, freq_base=10000, scale_base=512, gamma=0.4, precision=torch.half):
        super().__init__()
        self.scale_base = scale_base
        # Per-(half-)dimension decay coefficients in ((gamma)/(1+gamma), 1].
        half_dims = torch.arange(0, head_dim, 2)
        self.register_buffer(
            "scale",
            (half_dims + gamma * head_dim) / ((1.0 + gamma) * head_dim),
        )
        self.max_seq_len_cached = None
        self.precision = precision
        self.freq_base = freq_base

    def forward(self, x, seq_dim=1, seq_len=None):
        """Return (cos, sin, scale) tables of length ``seq_len``.

        Tables are cached and only recomputed when a longer sequence is seen.
        """
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        cache_stale = (
            self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached
        )
        if cache_stale:
            self.max_seq_len_cached = seq_len
            # Positions centered at seq_len // 2, scaled by scale_base.
            positions = (torch.arange(0, seq_len, 1) - seq_len // 2).to(self.scale)
            scale = self.scale ** positions.div(self.scale_base)[:, None]
            cos, sin = fixed_pos_embedding(scale, self.freq_base)
            self.cos_cached = cos
            self.sin_cached = sin
            self.scale_cached = scale
            if self.precision == torch.bfloat16:
                self.cos_cached = self.cos_cached.bfloat16()
                self.sin_cached = self.sin_cached.bfloat16()
        return (
            self.cos_cached[:seq_len],
            self.sin_cached[:seq_len],
            self.scale_cached[:seq_len],
        )
110+
111+
def rotate_every_two(x):
    """Pairwise rotation of the last dim: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)."""
    even = x[:, :, ::2]
    odd = x[:, :, 1::2]
    paired = torch.stack((-odd, even), dim=-1)
    # in einsum notation: rearrange(x, '... d j -> ... (d j)')
    return paired.flatten(-2)
117+
118+
def duplicate_interleave(m):
    """
    A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.

    Takes a (rows, cols) matrix and returns a (rows, 1, 2*cols) tensor where
    each element appears twice in a row.
    """
    rows = m.shape[0]
    column = m.view(-1, 1)        # flatten into a single column
    doubled = column.repeat(1, 2) # place a copy of each element beside it
    # Reshaping back interleaves original and copy along the last dim;
    # the unsqueezed middle dim broadcasts over the batch axis.
    return doubled.view(rows, -1).unsqueeze(1)
128+
129+
def _apply_xpos_emb(x, cos, sin, scale):
    # x is assumed to be (seq_len, batch_size, dim) here.
    # Scale the half-dim tables, then interleave so they line up with
    # the full (even, odd) feature pairs of x.
    scaled_cos = duplicate_interleave(cos * scale)
    scaled_sin = duplicate_interleave(sin * scale)
    # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
    return (x * scaled_cos) + (rotate_every_two(x) * scaled_sin)
136+
137+
@torch.jit.script
def apply_xpos_emb(q, k, cos, sin, scale, offset: int = 0):
    """Apply xPos embeddings to q and k.

    q/k are assumed to be (seq_len, batch_size, dim) here.  q is scaled by
    ``scale`` and k by ``1/scale`` so their dot product decays with relative
    distance, as in the xPos paper.
    """
    end = q.shape[0] + offset
    cos = cos[offset:end]
    sin = sin[offset:end]
    scale = scale[offset:end]
    return (
        _apply_xpos_emb(q, cos, sin, scale),
        # BUG FIX: the second output must embed k, not q a second time —
        # the original passed q twice and never used k.
        _apply_xpos_emb(k, cos, sin, 1.0 / scale),
    )
148+
149+
def apply_xpos_emb_torch(q, k, cos, sin, scale, offset: int = 0):
    """Eager-mode (non-scripted) twin of ``apply_xpos_emb``.

    q/k are assumed to be (seq_len, batch_size, dim) here.  q is scaled by
    ``scale`` and k by ``1/scale`` so their dot product decays with relative
    distance, as in the xPos paper.
    """
    end = q.shape[0] + offset
    cos = cos[offset:end]
    sin = sin[offset:end]
    scale = scale[offset:end]
    return (
        _apply_xpos_emb(q, cos, sin, scale),
        # BUG FIX: the second output must embed k, not q a second time —
        # the original passed q twice and never used k.
        _apply_xpos_emb(k, cos, sin, 1.0 / scale),
    )
0 commit comments