+import itertools
 import unittest
 from collections import defaultdict
 
@@ -19,13 +20,18 @@ def setUp(self):
         torch.manual_seed(42)
 
     def test_without_cache(self):
-        def test(use_qk_norm, use_conv2d):
+        def test(use_qk_norm, qk_norm_before_rope, adopt_hf_rope, use_conv2d):
+            if not use_qk_norm and qk_norm_before_rope:
+                # Redundant test.
+                return
+
             config = ModelArgs(
                 dim=64,
                 n_heads=4,
                 n_kv_heads=2,
                 max_seq_len=8,
                 use_qk_norm=use_qk_norm,
+                qk_norm_before_rope=qk_norm_before_rope,
             )
             layer_id = 0
             rope = Rope(config)
@@ -40,12 +46,19 @@ def test(use_qk_norm, use_conv2d):
                     torch.rand(config.head_dim) * 0.2 + 0.9
                 )
             static_attn.load_weights_from_attention_mha(attn_mha)
+            if adopt_hf_rope:
+                static_attn.adopt_hf_rope()
             if use_conv2d:
                 static_attn.linear_to_conv2d()
 
             x = torch.rand(1, config.max_seq_len, config.dim)
             freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
             expected, _ = attn_mha(x, freqs_cos, freqs_sin)
+
+            if adopt_hf_rope:
+                config.use_hf_rope = True
+                rope = Rope(config)
+                freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
             mask = torch.triu(
                 torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
                 diagonal=1,
@@ -56,45 +69,16 @@ def test(use_qk_norm, use_conv2d):
                 freqs_sin,
                 mask=mask,
             )
-            self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
-
-        test(True, True)
-        test(True, False)
-        test(False, True)
-        test(False, False)
-
-    def test_hf_rope_without_cache(self):
-        config = ModelArgs(
-            dim=64,
-            n_heads=4,
-            n_kv_heads=2,
-            max_seq_len=8,
-            use_qk_norm=True,
-            use_hf_rope=True,
-        )
-        layer_id = 0
-        rope = Rope(config)
-        attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        with torch.no_grad():
-            attn_mha.q_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
-            attn_mha.k_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+            self.assertTrue(
+                torch.isclose(y, expected, rtol=1e-3).all(),
+                f"Failed for use_qk_norm={use_qk_norm}, "
+                f"qk_norm_before_rope={qk_norm_before_rope}, "
+                f"adopt_hf_rope={adopt_hf_rope}, "
+                f"use_conv2d={use_conv2d}",
+            )
 
-        x = torch.rand(1, config.max_seq_len, config.dim)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
-        expected, _ = attn_mha(x, freqs_cos, freqs_sin)
-        mask = torch.triu(
-            torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
-            diagonal=1,
-        )
-        y, _ = static_attn(
-            x,
-            freqs_cos.unsqueeze(0),
-            freqs_sin.unsqueeze(0),
-            mask=mask,
-        )
-        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+        for args in itertools.product([False, True], repeat=4):
+            test(*args)
 
     def test_with_cache(self):
         config = ModelArgs(
@@ -108,6 +92,7 @@ def test_with_cache(self):
         attn_mha = AttentionMHA(config, layer_id, rope).eval()
         static_attn = StaticAttention(config, layer_id, rope).eval()
         static_attn.load_weights_from_attention_mha(attn_mha)
+        static_attn.adopt_hf_rope()
 
         x = torch.rand(1, config.max_seq_len, config.dim)
         freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
@@ -117,6 +102,10 @@ def test_with_cache(self):
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
 
+        config.use_hf_rope = True
+        hf_rope = Rope(config)
+        hf_freqs_cos, hf_freqs_sin = hf_rope.get_freqs(None, config.max_seq_len)
+
         def test_with_style(style):
             mask = StaticAttentionMask(chunk_len, cache_len, style=style)
             mask.tensor[:, :, cache_len:] = torch.triu(
@@ -139,8 +128,8 @@ def test_with_style(style):
             for i in range(n_chunks):
                 y_i, attn_update = static_attn(
                     x[:, i * chunk_len : (i + 1) * chunk_len, :],
-                    freqs_cos[i * chunk_len : (i + 1) * chunk_len],
-                    freqs_sin[i * chunk_len : (i + 1) * chunk_len],
+                    hf_freqs_cos[i * chunk_len : (i + 1) * chunk_len],
+                    hf_freqs_sin[i * chunk_len : (i + 1) * chunk_len],
                     mask=mask.tensor,
                     in_cache_state=(k_caches, v_caches),
                     out_cache_state=({}, {}),
@@ -175,6 +164,7 @@ def _get_test_transformers(self, config):
             mha_transformer.layers, static_transformer.layers
         ):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
+            static_layer.attention.adopt_hf_rope()
 
         return mha_transformer, static_transformer
 
@@ -196,6 +186,7 @@ def test_within_transformer(self):
         cache_len = config.max_seq_len - chunk_len
 
         def test_with_style(style):
+            config.use_hf_rope = True
             mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
             ys = []
             for i in range(n_chunks):
@@ -222,6 +213,7 @@ def test_lookahead_decode(self):
         )
         _, static_transformer = self._get_test_transformers(config)
 
+        config.use_hf_rope = True
         input_len = 32
         cache_len = config.max_seq_len - input_len
         prefill_input = torch.randint(config.vocab_size, (input_len,))