44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import tempfile
78import unittest
89from typing import Callable , Tuple
910
1011import torch
11-
1212from executorch .exir import EdgeCompileConfig , to_edge
13+
14+ from executorch .extension .export_util .utils import save_pte_program
1315from executorch .extension .llm .modules .kv_cache import KVCache as InferenceKVCache
16+
17+ from executorch .extension .pybindings .portable_lib import (
18+ _load_for_executorch_from_buffer ,
19+ )
1420from executorch .runtime import Runtime
1521from torch .testing import assert_close
1622from torchtune .modules .kv_cache import KVCache
@@ -67,21 +73,10 @@ def _test_kv_cache(self, et_cache_module: Callable):
6773 prefill_seq_len , self .batch_size , self .num_kv_heads , self .head_dim
6874 )
6975
70- print ()
71- print ("Prefilling..." )
72- print ()
73-
7476 et_res = et_cache_module (k_val , v_val )
7577 tt_res = self .tt_kv_cache .update (k_val_trans , v_val_trans )
7678 tt_res_transposed = (tt_res [0 ].transpose (1 , 2 ), tt_res [1 ].transpose (1 , 2 ))
7779
78- print ()
79- print ("Final tt kv_cache.cache_pos" )
80- print (self .tt_kv_cache .cache_pos )
81- print ("Final tt kv_cache.k_cache" )
82- print (self .tt_kv_cache .k_cache )
83- print ()
84-
8580 # Check torchtune matches executorch.
8681 assert_close (et_res , tt_res_transposed )
8782
@@ -112,7 +107,6 @@ def _test_kv_cache(self, et_cache_module: Callable):
112107
113108 self .assertTrue (et_k_cache [0 ][prefill_seq_len + 1 ][0 ][0 ] == 0 )
114109
115-
116110 def export_kv_cache (
117111 self ,
118112 kv_cache : torch .nn .Module ,
@@ -179,9 +173,6 @@ def test_kv_cache_executorch(self):
179173 )
180174 et_program = edge_program .to_executorch ()
181175
182- """DEBUG the executorch program"""
183- et_program .dump_executorch_program (verbose = True )
184-
185176 runtime = Runtime .get ()
186177 program = runtime .load_program (et_program .buffer )
187178 method = program .load_method ("forward" )
@@ -192,3 +183,27 @@ def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor:
192183
193184 self ._test_kv_cache (wrapped_callable )
194185
186+ def test_kv_cache_executorch_from_file (self ):
187+ exported_kv_cache = self .export_kv_cache (self .et_kv_cache )
188+ edge_program = to_edge (
189+ exported_kv_cache ,
190+ compile_config = EdgeCompileConfig (
191+ _core_aten_ops_exception_list = [torch .ops .aten ._assert_async .msg ],
192+ _check_ir_validity = False ,
193+ ),
194+ )
195+ et_program = edge_program .to_executorch ()
196+
197+ with tempfile .TemporaryDirectory () as tempdir :
198+ pte_path = save_pte_program (et_program , "test_et_kv_cache" , tempdir )
199+ with open (pte_path , "rb" ) as f :
200+ model_bytes = f .read ()
201+ loaded_et_program = _load_for_executorch_from_buffer (model_bytes )
202+
203+ # Since method.execute expects a tuple of args.
204+ def wrapped_callable (
205+ k_val : torch .Tensor , v_val : torch .Tensor
206+ ) -> torch .Tensor :
207+ return loaded_et_program .forward ((k_val , v_val ))
208+
209+ self ._test_kv_cache (wrapped_callable )
0 commit comments