
Commit d1cbf71

[feat] Add kv cache for InternVLA-N1 realworld deployment

1 parent: b0de752

2 files changed: 66 additions, 5 deletions

internnav/agent/internvla_n1_agent_realworld.py

3 additions, 3 deletions

@@ -202,10 +202,10 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False):
             **inputs,
             max_new_tokens=128,
             do_sample=False,
-            # use_cache=True,
-            # past_key_values=self.past_key_values,
+            use_cache=True,
+            past_key_values=self.past_key_values,
             return_dict_in_generate=True,
-            # raw_input_ids=copy.deepcopy(inputs.input_ids),
+            raw_input_ids=copy.deepcopy(inputs.input_ids),
         )
         output_ids = outputs.sequences
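
The agent-side change turns on incremental decoding: `use_cache=True` plus a persistent `past_key_values` lets `generate` reuse the attention cache across `step_s2` calls, while `raw_input_ids` hands the model a copy of the full prompt so it can rebuild RoPE indices (see the model change below). A minimal sketch of the surrounding bookkeeping, assuming the agent keeps the cache on `self.past_key_values` and reads it back from the `return_dict_in_generate=True` output; class and attribute names here are illustrative, not the repository's exact code:

import copy
import torch

class S2CacheSketch:
    """Hypothetical wrapper around the S2 policy model; names are illustrative."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.past_key_values = None  # reset whenever the episode / instruction changes

    @torch.no_grad()
    def step_s2(self, inputs):
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,
            use_cache=True,                                 # reuse attention keys/values across steps
            past_key_values=self.past_key_values,           # None on the first step, the cache afterwards
            return_dict_in_generate=True,
            raw_input_ids=copy.deepcopy(inputs.input_ids),  # custom kwarg consumed by InternVLA-N1's forward
        )
        # keep the updated cache for the next step (populated when use_cache=True)
        self.past_key_values = getattr(outputs, "past_key_values", None)
        return outputs.sequences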

internnav/model/basemodel/internvla_n1/internvla_n1.py

63 additions, 2 deletions

@@ -101,6 +101,49 @@ def __init__(self, config):
     def get_model(self):
         return self.model

+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        second_per_grid_ts=None,
+        **kwargs,
+    ):
+        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            second_per_grid_ts=second_per_grid_ts,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        # Qwen2.5-VL position_ids are prepared with rope_deltas in forward
+        model_inputs["position_ids"] = None
+
+        # keep pixel inputs in model_inputs for the QwenVL kv cache
+        model_inputs["pixel_values"] = pixel_values
+        model_inputs["pixel_values_videos"] = pixel_values_videos
+
+        return model_inputs
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

@@ -121,6 +164,7 @@ def forward(
         rope_deltas: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         second_per_grid_ts: Optional[torch.Tensor] = None,
+        raw_input_ids: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -169,10 +213,11 @@ def forward(

         if inputs_embeds is None:
             inputs_embeds = self.model.embed_tokens(input_ids)
-            if pixel_values is not None:
+            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+            if pixel_values is not None and n_image_tokens > 0:
                 pixel_values = pixel_values.type(self.visual.dtype)
                 image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+                image_embeds = image_embeds[-n_image_tokens:]
                 n_image_features = image_embeds.shape[0]
                 if n_image_tokens != n_image_features:
                     raise ValueError(

@@ -232,6 +277,22 @@ def forward(
                     attention_mask,
                 )
                 self.rope_deltas = rope_deltas
+            elif n_image_tokens > 0:  # used only for the kv cache path
+                attention_mask = attention_mask[:, :raw_input_ids.shape[1]]
+                position_ids, rope_deltas = self.get_rope_index(
+                    raw_input_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    second_per_grid_ts,
+                    attention_mask,
+                )
+                delta = (
+                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+                    if cache_position is not None
+                    else 0
+                )
+                position_ids = position_ids[:, :, -input_ids.shape[1]:]
+                self.rope_deltas = rope_deltas
             # then use the prev pre-calculated rope-deltas to get the correct position ids
             else:
                 batch_size, seq_length, _ = inputs_embeds.shape
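
On the model side, `prepare_inputs_for_generation` keeps `pixel_values`/`pixel_values_videos` in the prepared inputs, the forward pass only re-encodes vision when the current (un-cached) `input_ids` still contains image placeholders, and the `raw_input_ids` copy of the full prompt is used to rebuild RoPE position ids before trimming them to the suffix length. The toy below, with made-up shapes and an illustrative token id, shows the arithmetic behind `image_embeds[-n_image_tokens:]`: only the trailing placeholders that are not already covered by the cache need fresh visual features.

import torch

IMAGE_TOKEN_ID = 151655  # illustrative placeholder id, not necessarily the model's config value

# Suppose the full prompt contained 6 image placeholder tokens and the kv cache
# already covers 4 of them; this forward pass only sees the remaining suffix.
suffix_input_ids = torch.tensor([[IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 42, 7]])
image_embeds = torch.randn(6, 8)  # one row of visual features per placeholder in the full prompt (assumed)

n_image_tokens = (suffix_input_ids == IMAGE_TOKEN_ID).sum().item()  # -> 2
image_embeds = image_embeds[-n_image_tokens:]                       # keep only the last 2 rows
assert image_embeds.shape[0] == n_image_tokens                      # mirrors the n_image_features check

# Position ids rebuilt from raw_input_ids span the whole prompt and are trimmed the
# same way: position_ids[:, :, -suffix_input_ids.shape[1]:]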
