@@ -67,10 +67,21 @@ def _test_kv_cache(self, et_cache_module: Callable):
6767 prefill_seq_len , self .batch_size , self .num_kv_heads , self .head_dim
6868 )
6969
70+ print ()
71+ print ("Prefilling..." )
72+ print ()
73+
7074 et_res = et_cache_module (k_val , v_val )
7175 tt_res = self .tt_kv_cache .update (k_val_trans , v_val_trans )
7276 tt_res_transposed = (tt_res [0 ].transpose (1 , 2 ), tt_res [1 ].transpose (1 , 2 ))
7377
78+ print ()
79+ print ("Final tt kv_cache.cache_pos" )
80+ print (self .tt_kv_cache .cache_pos )
81+ print ("Final tt kv_cache.k_cache" )
82+ print (self .tt_kv_cache .k_cache )
83+ print ()
84+
7485 # Check torchtune matches executorch.
7586 assert_close (et_res , tt_res_transposed )
7687
@@ -89,17 +100,19 @@ def _test_kv_cache(self, et_cache_module: Callable):
89100
90101 et_res = et_cache_module (k_val , v_val )
91102 tt_res = self .tt_kv_cache .update (k_val_trans , v_val_trans )
103+ tt_res_transposed = (tt_res [0 ].transpose (1 , 2 ), tt_res [1 ].transpose (1 , 2 ))
92104
93105 # Check torchtune matches executorch.
94- tt_res_transposed = (tt_res [0 ].transpose (1 , 2 ), tt_res [1 ].transpose (1 , 2 ))
95106 assert_close (tt_res_transposed , et_res )
96107
97108 # All rows should be filled with 1s up to 3 + 1th row.
98109 et_k_cache = et_res [0 ]
99110 for i in range (prefill_seq_len + 1 ):
100111 self .assertTrue (et_k_cache [0 ][i ][0 ][0 ] == 1 )
112+
101113 self .assertTrue (et_k_cache [0 ][prefill_seq_len + 1 ][0 ][0 ] == 0 )
102114
115+
103116 def export_kv_cache (
104117 self ,
105118 kv_cache : torch .nn .Module ,
@@ -165,6 +178,10 @@ def test_kv_cache_executorch(self):
165178 ),
166179 )
167180 et_program = edge_program .to_executorch ()
181+
182+ """DEBUG the executorch program"""
183+ et_program .dump_executorch_program (verbose = True )
184+
168185 runtime = Runtime .get ()
169186 program = runtime .load_program (et_program .buffer )
170187 method = program .load_method ("forward" )
@@ -174,3 +191,4 @@ def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor:
174191 return method .execute ((k_val , v_val ))
175192
176193 self ._test_kv_cache (wrapped_callable )
194+
0 commit comments