
Commit ed906cb

Kv cache

1 parent e145bd1 commit ed906cb

File tree

1 file changed: +20 −6 lines


examples/models/llama3_2_vision/model.py

Lines changed: 20 additions & 6 deletions
@@ -40,16 +40,14 @@ class Llama3_2Decoder(EagerModelBase):
 
     def __init__(self, **kwargs):
         # Set member vars from kwargs.
-        self.max_seq_len = kwargs.get("max_seq_len", 8192)
+        self.max_seq_len = kwargs.get("max_seq_len", 8192)  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
         self.encoder_max_seq_len = kwargs.get(
             "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
-        )
+        )  # Same as above.
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-        # TODO: enable kv cache with TransformerDecoder's setup_cache().
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
-        self.use_sdpa_with_kv_cache = kwargs.get("use_sdpa_with_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
 
@@ -60,6 +58,14 @@ def __init__(self, **kwargs):
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
 
+        self.causal_mask = torch.tril(
+            torch.ones(
+                size=(self.max_seq_len, self.max_seq_len),
+                dtype=torch.bool,
+            )
+        )
+        self.input_pos = torch.arange(self.max_seq_len)
+
         # Load checkpoint and params.
         device = "cpu"
         if checkpoint_dir is not None:
@@ -126,22 +132,30 @@ def __init__(self, **kwargs):
 
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
+        # if self.use_kv_cache:
+        #     print("Setting up KV cache on the model...")
+        #     self.model_.setup_caches(
+        #         batch_size=1,
+        #         dtype=self.dtype,
+        #     )
+
     def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
             return self.model_.to(self.dtype)
         else:
             return self.model_.to(torch.float16)
 
     def get_example_inputs(self):
-        return (torch.ones(1, 64, dtype=torch.long),)  # positional inputs
+        return (torch.ones(1, 64, dtype=torch.long),)
 
     def get_example_kwarg_inputs(self):
         # TODO: add input_pos and mask when after making cache work.
         return {
-            # "mask": None,
+            # "mask": self.causal_mask[None, 64, None, :],
             # "encoder_input": None,
             # "encoder_mask": None,
             # "input_pos": torch.ones(64, dtype=torch.long),
+            # input_pos: self.input_pos[None, 64]
         }
 
     def get_dynamic_shapes(self):