22prompt = "Lily picked up a flower."
33model_name = "Maykeye/TinyLLama-v0"
44
5- captured_input = None # type: ignore[var-annotated]
5+ captured_input = ()
66
77import copy , inspect , types
88
@@ -20,7 +20,9 @@ def capture_and_forward(self, *args, **kwargs):
2020 args_names = [
2121 # signature includes `self`` and `kwargs``.
2222 # Just retrieve the ordinary positional inputs only
23- name for name in sig .parameters .keys () if name not in ("self" , "kwargs" )
23+ name
24+ for name in sig .parameters .keys ()
25+ if name not in ("self" , "kwargs" )
2426 ]
2527
2628 args_dict = dict (zip (args_names , args ))
@@ -32,8 +34,8 @@ def populate_args(args_dict, filter):
3234 args_tuple = tuple (args_dict .get (name , None ) for name in args_names )
3335 return copy .deepcopy (args_tuple )
3436
35- if len (args_dict [' past_key_value' ].key_cache ) != 0 :
36- input_to_remove = [ "use_cache" ]
37+ if len (args_dict [" past_key_value" ].key_cache ) != 0 :
38+ input_to_remove = ["use_cache" ]
3739 captured_input = populate_args (args_dict , input_to_remove )
3840
3941 return forward_old (self , * args , ** kwargs )
@@ -61,7 +63,9 @@ def populate_args(args_dict, filter):
6163
6264model = AutoModelForCausalLM .from_pretrained (model_name )
6365model .eval ()
64- model .model .layers [0 ].forward = types .MethodType (capture_and_forward , model .model .layers [0 ])
66+ model .model .layers [0 ].forward = types .MethodType (
67+ capture_and_forward , model .model .layers [0 ]
68+ )
6569with torch .no_grad ():
6670 outputs = model .generate (
6771 ** inputs ,
0 commit comments