Skip to content

Commit 5bc2ce6

Browse files
author
Sanggyu Lee
committed
Make format happy
1 parent 55fa695 commit 5bc2ce6

File tree

1 file changed

+9
-5
lines changed
  • test/modules/model/LlamaDecoderLayerWithCache

1 file changed

+9
-5
lines changed

test/modules/model/LlamaDecoderLayerWithCache/model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
prompt = "Lily picked up a flower."
33
model_name = "Maykeye/TinyLLama-v0"
44

5-
captured_input = None # type: ignore[var-annotated]
5+
captured_input = ()
66

77
import copy, inspect, types
88

@@ -20,7 +20,9 @@ def capture_and_forward(self, *args, **kwargs):
2020
args_names = [
2121
# signature includes `self` and `kwargs`.
2222
# Just retrieve the ordinary positional inputs only
23-
name for name in sig.parameters.keys() if name not in ("self", "kwargs")
23+
name
24+
for name in sig.parameters.keys()
25+
if name not in ("self", "kwargs")
2426
]
2527

2628
args_dict = dict(zip(args_names, args))
@@ -32,8 +34,8 @@ def populate_args(args_dict, filter):
3234
args_tuple = tuple(args_dict.get(name, None) for name in args_names)
3335
return copy.deepcopy(args_tuple)
3436

35-
if len(args_dict['past_key_value'].key_cache) != 0:
36-
input_to_remove = [ "use_cache" ]
37+
if len(args_dict["past_key_value"].key_cache) != 0:
38+
input_to_remove = ["use_cache"]
3739
captured_input = populate_args(args_dict, input_to_remove)
3840

3941
return forward_old(self, *args, **kwargs)
@@ -61,7 +63,9 @@ def populate_args(args_dict, filter):
6163

6264
model = AutoModelForCausalLM.from_pretrained(model_name)
6365
model.eval()
64-
model.model.layers[0].forward = types.MethodType(capture_and_forward, model.model.layers[0])
66+
model.model.layers[0].forward = types.MethodType(
67+
capture_and_forward, model.model.layers[0]
68+
)
6569
with torch.no_grad():
6670
outputs = model.generate(
6771
**inputs,

0 commit comments

Comments
 (0)