             The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
         dtype (`torch.dtype`):
             The dtype to use for the 4D attention mask.
-        past_key_values (`Cache`):
-            The cache class that is being used currently to generate
         cache_position (`torch.LongTensor`):
             Indices depicting the position of the input sequence tokens in the sequence.
         batch_size (`int`):
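The `target_length` behaviour described above can be illustrated with a small standalone mask builder. This is only a sketch of the idea, not the function being patched; the helper name `build_causal_mask` and its exact signature are assumptions. Keys at indices beyond each query's `cache_position` are masked, which covers both future tokens and the still-empty, zero-padded tail of a static cache.

```python
import torch

def build_causal_mask(sequence_length: int, target_length: int, dtype: torch.dtype,
                      cache_position: torch.LongTensor, batch_size: int) -> torch.Tensor:
    # Hypothetical sketch: build a (batch, 1, sequence_length, target_length) additive
    # attention mask, where target_length is the full static-cache size.
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    # Mask every key index that lies after the query's absolute position; this also
    # blanks out the not-yet-filled part of the static cache.
    future = torch.arange(target_length) > cache_position.reshape(-1, 1)
    mask = mask * future  # 0 where attendable, min_dtype where masked
    return mask[None, None, :, :].expand(batch_size, 1, -1, -1)
```

With this shape convention, a prefill call (`cache_position = torch.arange(seq_len)`) and a single decode step (`cache_position = torch.tensor([current_index])`) both yield masks whose last dimension equals the static cache length.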
@@ -1199,14 +1232,29 @@ def forward(
         if position_ids is None:
             assert attention_mask is None or attention_mask.ndim == 2, "attention mask must be 2D"
             # calculate RoPE index once per generation in the pre-fill stage only
-            position_ids, rope_deltas = self.get_rope_index(
-                input_ids,
-                image_grid_thw,
-                video_grid_thw,
-                second_per_grid_ts,
-                attention_mask,
+            is_prefill = (
+                (cache_position is None or cache_position[0] == 0)
+                or self.rope_deltas is None
+                or past_key_values is None
             )
-            self.rope_deltas = rope_deltas
+            if is_prefill:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    second_per_grid_ts,
+                    attention_mask,
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
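The new guard recomputes the multimodal RoPE index only at prefill (empty cache, or no cached `rope_deltas` / `past_key_values`). The trailing comment refers to the decode path, where the cached offset is combined with `cache_position` to rebuild the correct position ids each step. A minimal sketch of that idea, assuming `rope_deltas` has shape `(batch, 1)`; the helper name is hypothetical and the model's actual decode branch differs in details:

```python
import torch

def decode_position_ids(cache_position: torch.LongTensor,
                        rope_deltas: torch.Tensor) -> torch.LongTensor:
    # cache_position: absolute indices of the current tokens (usually length 1 when decoding).
    seq_length = cache_position.shape[0]
    # The per-sample rope_deltas cached at prefill shift plain text positions to account
    # for the extra positions occupied by image/video tokens in the multimodal RoPE index.
    delta = cache_position[0] + rope_deltas                       # (batch, 1)
    position_ids = torch.arange(seq_length).view(1, -1) + delta   # (batch, seq_length)
    # The multimodal RoPE expects (3, batch, seq) ids; during text-only decoding the
    # temporal/height/width components advance together, so the same ids are broadcast.
    return position_ids.unsqueeze(0).expand(3, -1, -1)
```

Because the three RoPE sections move in lockstep once generation is text-only, caching a single `(batch, 1)` delta at prefill is enough to recover the full position ids without rerunning `get_rope_index`.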