@@ -41,15 +41,15 @@ def setUp(self):
4141        )
4242    )
4343
44-   def  _get_params (self , enable_hlfb : bool ):
44+   def  _get_params (self , enable_hlfb : bool ,  kv_layout :  kv_cache . KVLayout ):
4545    """Returns a model, edge model and the kwargs to use for testing.""" 
4646    config  =  toy_model_with_kv_cache .get_model_config ()
4747    config .enable_hlfb  =  enable_hlfb 
4848    pytorch_model  =  toy_model_with_kv_cache .ToyModelWithKVCache (config ).eval ()
4949    tokens , input_pos  =  torch .tensor ([[1 ]], dtype = torch .int ), torch .tensor (
5050        [10 ], dtype = torch .int 
5151    )
52-     kv  =  kv_cache .KVCache .from_model_config (config )
52+     kv  =  kv_cache .KVCache .from_model_config (config ,  kv_layout = kv_layout )
5353    kwargs  =  {
5454        "tokens" : tokens ,
5555        "input_pos" : input_pos ,
@@ -65,8 +65,12 @@ def _get_params(self, enable_hlfb: bool):
6565    )
6666    return  pytorch_model , edge_model , kwargs 
6767
68-   def  _test_model_with_kv_cache (self , enable_hlfb : bool ):
69-     pytorch_model , edge_model , kwargs  =  self ._get_params (enable_hlfb )
68+   def  _test_model_with_kv_cache (
69+       self ,
70+       enable_hlfb : bool  =  False ,
71+       kv_layout : kv_cache .KVLayout  =  kv_cache .KV_LAYOUT_DEFAULT ,
72+   ):
73+     pytorch_model , edge_model , kwargs  =  self ._get_params (enable_hlfb , kv_layout )
7074
7175    self .assertTrue (
7276        test_utils .compare_tflite_torch (
@@ -95,13 +99,22 @@ def test_toy_model_with_kv_cache(self):
9599  def  test_toy_model_with_kv_cache_with_hlfb (self ):
96100    self ._test_model_with_kv_cache (enable_hlfb = True )
97101
102+   @googletest .skipIf ( 
103+       ai_edge_torch .config .in_oss , 
104+       reason = "tests with custom ops are not supported in oss" , 
105+   ) 
106+   def  test_toy_model_with_kv_cache_transposed (self ):
107+     self ._test_model_with_kv_cache (kv_layout = kv_cache .KV_LAYOUT_TRANSPOSED )
108+ 
98109  @googletest .skipIf ( 
99110      ai_edge_torch .config .in_oss , 
100111      reason = "tests with custom ops are not supported in oss" , 
101112  ) 
102113  def  test_toy_model_has_dus_op (self ):
103114    """Tests that the model has the dynamic update slice op.""" 
104-     _ , edge_model , _  =  self ._get_params (enable_hlfb = True )
115+     _ , edge_model , _  =  self ._get_params (
116+         enable_hlfb = True , kv_layout = kv_cache .KV_LAYOUT_DEFAULT 
117+     )
105118    interpreter_  =  interpreter .InterpreterWithCustomOps (
106119        custom_op_registerers = ["GenAIOpsRegisterer" ],
107120        model_content = edge_model .tflite_model (),
@@ -112,7 +125,14 @@ def test_toy_model_has_dus_op(self):
112125    op_names  =  [op ["op_name" ] for  op  in  interpreter_ ._get_ops_details ()]
113126    self .assertIn ("DYNAMIC_UPDATE_SLICE" , op_names )
114127
115-   def  _test_multisig_model (self , config , pytorch_model , atol , rtol ):
128+   def  _test_multisig_model (
129+       self ,
130+       config ,
131+       pytorch_model ,
132+       atol ,
133+       rtol ,
134+       kv_layout = kv_cache .KV_LAYOUT_DEFAULT ,
135+   ):
116136    # prefill 
117137    seq_len  =  10 
118138    prefill_tokens  =  torch .zeros ((1 , seq_len ), dtype = torch .int , device = "cpu" )
@@ -124,7 +144,7 @@ def _test_multisig_model(self, config, pytorch_model, atol, rtol):
124144    decode_token  =  torch .tensor ([[1 ]], dtype = torch .int )
125145    decode_input_pos  =  torch .tensor ([5 ], dtype = torch .int )
126146
127-     kv  =  kv_cache .KVCache .from_model_config (config )
147+     kv  =  kv_cache .KVCache .from_model_config (config ,  kv_layout = kv_layout )
128148
129149    edge_model  =  (
130150        ai_edge_torch .signature (
@@ -160,7 +180,7 @@ def _test_multisig_model(self, config, pytorch_model, atol, rtol):
160180            kv ,
161181            signature_name = "prefill" ,
162182            atol = atol ,
163-             rtol = atol ,
183+             rtol = rtol ,
164184        )
165185    )
166186
@@ -173,7 +193,7 @@ def _test_multisig_model(self, config, pytorch_model, atol, rtol):
173193            kv ,
174194            signature_name = "decode" ,
175195            atol = atol ,
176-             rtol = atol ,
196+             rtol = rtol ,
177197        )
178198    )
179199
@@ -186,6 +206,21 @@ def test_tiny_llama_multisig(self):
186206    pytorch_model  =  tiny_llama .TinyLlama (config ).eval ()
187207    self ._test_multisig_model (config , pytorch_model , atol = 1e-5 , rtol = 1e-5 )
188208
209+   @googletest .skipIf ( 
210+       ai_edge_torch .config .in_oss , 
211+       reason = "tests with custom ops are not supported in oss" , 
212+   ) 
213+   def  test_tiny_llama_multisig_kv_layout_transposed (self ):
214+     config  =  tiny_llama .get_fake_model_config ()
215+     pytorch_model  =  tiny_llama .TinyLlama (config ).eval ()
216+     self ._test_multisig_model (
217+         config ,
218+         pytorch_model ,
219+         atol = 1e-5 ,
220+         rtol = 1e-5 ,
221+         kv_layout = kv_cache .KV_LAYOUT_TRANSPOSED ,
222+     )
223+ 
189224
# Run the test suite only when this file is executed directly, not on import.
if __name__ == "__main__":
  googletest.main()
0 commit comments