@@ -11,7 +11,7 @@ def _deepstack_add_kernel(
1111 Out ,
1212 Img_token_lens ,
1313 Img_start_token_ids ,
14- Img_start_locs ,
14+ Img_start_locs_in_cache ,
1515 stride_deep_s ,
1616 stride_deep_d ,
1717 stride_out_s ,
@@ -26,8 +26,8 @@ def _deepstack_add_kernel(
2626 off_d = tl .arange (0 , BLOCK_DIM )
2727
2828 img_start_token_id = tl .load (Img_start_token_ids + img_handle_id )
29- img_start_loc = tl .load (Img_start_locs + img_handle_id )
3029 img_token_len = tl .load (Img_token_lens + img_handle_id )
30+ img_start_loc_in_cache = tl .load (Img_start_locs_in_cache + img_handle_id )
3131
3232 # 判断当前 token 是否属于这个 image
3333 cond = (token_id >= img_start_token_id ) & (token_id < img_start_token_id + img_token_len )
@@ -36,7 +36,7 @@ def _deepstack_add_kernel(
3636 token_offset = token_id - img_start_token_id
3737
3838 deep_row = tl .load (
39- Deepstack_embs + stride_deep_s * (img_start_loc + token_offset ) + off_d ,
39+ Deepstack_embs + stride_deep_s * (img_start_loc_in_cache + token_offset ) + off_d ,
4040 mask = off_d < hidden_size ,
4141 other = 0 ,
4242 )
@@ -60,7 +60,7 @@ def add_deepstack_embs(
6060 deepstack_embs : torch .Tensor ,
6161 img_token_lens : torch .Tensor ,
6262 img_start_token_ids : torch .Tensor ,
63- img_start_locs : torch .Tensor ,
63+ img_start_locs_in_cache : torch .Tensor ,
6464):
6565 assert input_ids .dim () == 1
6666 assert out .dim () == 2
@@ -79,7 +79,7 @@ def add_deepstack_embs(
7979 out ,
8080 img_token_lens ,
8181 img_start_token_ids ,
82- img_start_locs ,
82+ img_start_locs_in_cache ,
8383 deepstack_embs .stride (0 ),
8484 deepstack_embs .stride (1 ),
8585 out .stride (0 ),
@@ -105,20 +105,17 @@ def apply_deepstack_features(
105105 if not infer_state .deepstack_features :
106106 return
107107
108- if layer_num >= len (infer_state .deepstack_features [0 ]):
109- return
108+ deepstack_num_layers = infer_state .cpu_embed_cache_tensor .shape [1 ] - 1
110109
111- per_img_deepstack_features = [
112- infer_state .deepstack_features [i ][layer_num ] for i in range (infer_state .img_token_lens .shape [0 ])
113- ]
114- all_deepstack_features = torch .cat (per_img_deepstack_features , dim = 0 )
110+ if layer_num >= deepstack_num_layers :
111+ return
115112
116113 add_deepstack_embs (
117114 out = input_embeddings ,
118115 input_ids = infer_state .input_ids ,
119- deepstack_embs = all_deepstack_features ,
116+ deepstack_embs = infer_state . cpu_embed_cache_tensor [:, layer_num + 1 , :] ,
120117 img_token_lens = infer_state .img_token_lens ,
121118 img_start_token_ids = infer_state .img_start_token_ids ,
122- img_start_locs = infer_state .img_start_locs ,
119+ img_start_locs_in_cache = infer_state .img_start_locs_in_cache ,
123120 )
124121 return
0 commit comments