@@ -17,32 +17,32 @@ def _deepstack_add_kernel(
1717 stride_out_s ,
1818 stride_out_d ,
1919 hidden_size ,
20- BLOCK_HIDDEN_DIM : tl .constexpr ,
20+ BLOCK_DIM : tl .constexpr ,
2121):
2222 seq_index = tl .program_id (0 ).to (tl .int64 )
2323 img_handle_id = tl .program_id (1 )
2424
2525 token_id = tl .load (input_ids + seq_index )
26- off_d = tl .arange (0 , BLOCK_HIDDEN_DIM )
26+ off_d = tl .arange (0 , BLOCK_DIM )
2727
2828 img_start_token_id = tl .load (
29- Img_start_token_ids + img_handle_id - 1 ,
30- mask = img_handle_id >= 1 ,
29+ Img_start_token_ids + img_handle_id ,
30+ mask = img_handle_id >= 0 ,
3131 other = 0 ,
3232 )
3333 img_start_loc = tl .load (
34- Img_start_locs + img_handle_id - 1 ,
35- mask = img_handle_id >= 1 ,
34+ Img_start_locs + img_handle_id ,
35+ mask = img_handle_id >= 0 ,
3636 other = 0 ,
3737 )
3838 img_token_len = tl .load (
39- Img_token_lens + img_handle_id - 1 ,
40- mask = img_handle_id >= 1 ,
39+ Img_token_lens + img_handle_id ,
40+ mask = img_handle_id >= 0 ,
4141 other = 0 ,
4242 )
4343
4444 # 判断当前 token 是否属于这个 image
45- cond = (img_handle_id != 0 ) & ( token_id >= img_start_token_id ) & (token_id < img_start_token_id + img_token_len )
45+ cond = (token_id >= img_start_token_id ) & (token_id < img_start_token_id + img_token_len )
4646
4747 for _ in range (0 , tl .where (cond , 1 , 0 ), 1 ):
4848 token_offset = token_id - img_start_token_id
@@ -85,8 +85,8 @@ def add_deepstack_embs(
8585 hidden = out .shape [1 ]
8686 BLOCK = triton .next_power_of_2 (hidden )
8787
88- grid = (total_len , img_token_lens .shape [0 ] + 1 )
89- num_warps = 1
88+ grid = (total_len , img_token_lens .shape [0 ])
89+ num_warps = 4
9090
9191 _deepstack_add_kernel [grid ](
9292 input_ids ,
@@ -100,7 +100,7 @@ def add_deepstack_embs(
100100 out .stride (0 ),
101101 out .stride (1 ),
102102 hidden_size = hidden ,
103- BLOCK_HIDDEN_DIM = BLOCK ,
103+ BLOCK_DIM = BLOCK ,
104104 num_warps = num_warps ,
105105 num_stages = 1 ,
106106 )
@@ -117,7 +117,6 @@ def clear_deepstack_state(
117117 infer_state .img_start_token_ids = []
118118 infer_state .img_token_lens = None
119119 infer_state .img_start_locs = None
120- infer_state .image_num_need_deepstack = 0
121120 infer_state .deepstack_features = []
122121 return
123122
@@ -143,15 +142,14 @@ def apply_deepstack_features(
143142
144143 input_ids = infer_state .input_ids
145144 device = input_embeddings .device
146- dtype = input_embeddings .dtype
147145
148- if infer_state .image_num_need_deepstack == 0 :
146+ if infer_state .img_token_lens . shape [ 0 ] == 0 :
149147 clear_deepstack_state (layer_num , infer_state )
150148 return
151149
152150 per_img_deepstack_features = [
153- infer_state .deepstack_features [i ][layer_num ].to (device = device , dtype = dtype , non_blocking = True )
154- for i in range (infer_state .image_num_need_deepstack )
151+ infer_state .deepstack_features [i ][layer_num ].to (device = device , non_blocking = True )
152+ for i in range (infer_state .img_token_lens . shape [ 0 ] )
155153 ]
156154 all_deepstack_features = torch .cat (per_img_deepstack_features , dim = 0 )
157155
0 commit comments