@@ -205,9 +205,13 @@ def replace_kv_cache_with_quantized_kv_cache(module):
205205 # This is needed to ensure that custom ops are registered
206206 from executorch .extension .llm .custom_ops import custom_ops # noqa: F401
207207
208- logging .warning (
208+ logging .info (
209209 "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
210210 )
211+ return _replace_kv_cache_with_quantized_kv_cache (module )
212+
213+
214+ def _replace_kv_cache_with_quantized_kv_cache (module ):
211215 for name , child in module .named_children ():
212216 if isinstance (child , KVCache ) or isinstance (child , CustomKVCache ):
213217 setattr (
@@ -220,7 +224,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
220224 ),
221225 )
222226 else :
223- replace_kv_cache_with_quantized_kv_cache (child )
227+ _replace_kv_cache_with_quantized_kv_cache (child )
224228 return module
225229
226230
@@ -263,16 +267,20 @@ def update(
263267
264268
def replace_kv_cache_with_custom_kv_cache(module):
    """
    Replace KVCache with CustomKVCache. This modifies the model in place.

    At the moment custom kv cache only supports cache with shape
    [B, S, H, D] as opposed to [B, H, S, D].
    This is because the custom op treats second dim as sequence dim.
    Future work: support [B, H, S, D]

    Args:
        module: Root of the module tree to rewrite. Its children are walked
            with ``named_children()`` by the private recursive helper.

    Returns:
        Whatever ``_replace_kv_cache_with_custom_kv_cache`` returns, which is
        the same ``module`` object that was passed in.
    """
    # Log in the public entry point rather than in the recursive helper so the
    # message is emitted once per call instead of once per visited submodule.
    logging.info(
        "Replacing KVCache with CustomKVCache. This modifies the model in place."
    )
    return _replace_kv_cache_with_custom_kv_cache(module)
281+
282+
283+ def _replace_kv_cache_with_custom_kv_cache (module ):
276284 for name , child in module .named_children ():
277285 if isinstance (child , KVCache ):
278286 cache_shape = child .k_cache .shape
@@ -290,5 +298,5 @@ def replace_kv_cache_with_custom_kv_cache(module):
290298 ),
291299 )
292300 else :
293- replace_kv_cache_with_custom_kv_cache (child )
301+ _replace_kv_cache_with_custom_kv_cache (child )
294302 return module
0 commit comments