@@ -46,7 +46,7 @@ def build_grid_gather_index_and_bucket_fast(H, W, pool_h, pool_w, seqlen):
         Gh = (H + pool_h - 1) // pool_h
         Gw = (W + pool_w - 1) // pool_w

-        # 单帧
+        # Single frame
         gather_single = []
         bucket_sizes_single = []

@@ -90,23 +90,22 @@ def build_grid_gather_index_and_bucket_fast(H, W, pool_h, pool_w, seqlen):

     @classmethod
     @torch.compiler.disable
-    def prepare_reorg_idx_and_bucket_offset(cls, seqlen, frame_h, frame_w):
+    def prepare_reorg_idx_and_bucket_offset(cls, seqlen, frame_h, frame_w, pool_h, pool_w, device):
         if (seqlen, frame_h, frame_w) in cls.reorg_idx_dict:
             return
-        pool_h, pool_w = (8, 16) if frame_h < frame_w else (16, 8)
         reorg_idx, bucket_sizes, bucket_offsets = cls.build_grid_gather_index_and_bucket_fast(
             H=frame_h,
             W=frame_w,
             pool_h=pool_h,
             pool_w=pool_w,
             seqlen=seqlen,
         )
-        reorg_idx = torch.tensor(reorg_idx, dtype=torch.long, device="cuda")
+        reorg_idx = torch.tensor(reorg_idx, dtype=torch.long, device=device)
         restore_idx = torch.empty_like(reorg_idx)
-        restore_idx[reorg_idx] = torch.arange(reorg_idx.numel(), device=reorg_idx.device)
+        restore_idx[reorg_idx] = torch.arange(reorg_idx.numel(), device=device)
         cls.reorg_idx_dict[(seqlen, frame_h, frame_w)] = reorg_idx
         cls.restore_idx_dict[(seqlen, frame_h, frame_w)] = restore_idx
-        cls.bucket_offsets_dict[(seqlen, frame_h, frame_w)] = torch.tensor(bucket_offsets, dtype=torch.int32, device="cuda")
+        cls.bucket_offsets_dict[(seqlen, frame_h, frame_w)] = torch.tensor(bucket_offsets, dtype=torch.int32, device=device)
         logger.info(f"DraftAttnWeight: reorg_idx len: {len(reorg_idx)}")
         logger.info(f"DraftAttnWeight: bucket_sizes: {bucket_sizes}")
         logger.info(f"DraftAttnWeight: bucket_offsets: {bucket_offsets}")
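For context (not part of the patch): the restore_idx line above builds the inverse permutation of reorg_idx, so gathering with reorg_idx and then with restore_idx returns tokens to their original order. A minimal standalone sketch of that trick, using a made-up 4-element permutation:

import torch

# Made-up toy permutation standing in for reorg_idx.
reorg_idx = torch.tensor([2, 0, 3, 1])
x = torch.arange(4) * 10                       # original order: [0, 10, 20, 30]

reordered = x[reorg_idx]                       # tokens gathered into bucket order

# Inverse permutation: restore_idx[reorg_idx[i]] = i for every position i.
restore_idx = torch.empty_like(reorg_idx)
restore_idx[reorg_idx] = torch.arange(reorg_idx.numel())

assert torch.equal(reordered[restore_idx], x)  # round-trips to the original order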
@@ -136,7 +135,7 @@ def sample_qk_attention_2d(
         q_vid = q_vid.permute(0, 3, 4, 1, 2).reshape(num_frames, H * D, frame_h, frame_w)
         k_vid = k_vid.permute(0, 3, 4, 1, 2).reshape(num_frames, H * D, frame_h, frame_w)

-        # 3) 2D max-pool each frame (ceil_mode ensures we cover the edges):
+        # 3) 2D avg-pool each frame (ceil_mode ensures we cover the edges):
         #    → [num_frames, H*D, S_h, S_w]
         q_pooled = F.avg_pool2d(q_vid, kernel_size=(pool_h, pool_w), stride=(pool_h, pool_w), ceil_mode=True)
         k_pooled = F.avg_pool2d(k_vid, kernel_size=(pool_h, pool_w), stride=(pool_h, pool_w), ceil_mode=True)
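For context (not part of the patch): the avg-pooling above reduces each frame's q/k to a coarse S_h × S_w grid before the draft attention map is computed. A minimal sketch of just that step, with invented sizes and q already reshaped to [num_frames, H*D, frame_h, frame_w]:

import torch
import torch.nn.functional as F

# Invented sizes for illustration only.
num_frames, H, D, frame_h, frame_w = 2, 4, 8, 30, 45
pool_h, pool_w = (8, 16) if frame_h < frame_w else (16, 8)

q_vid = torch.randn(num_frames, H * D, frame_h, frame_w)

# ceil_mode=True keeps the partial cells at the right/bottom edges.
q_pooled = F.avg_pool2d(q_vid, kernel_size=(pool_h, pool_w),
                        stride=(pool_h, pool_w), ceil_mode=True)

S_h = (frame_h + pool_h - 1) // pool_h         # 4 for the sizes above
S_w = (frame_w + pool_w - 1) // pool_w         # 3 for the sizes above
assert q_pooled.shape == (num_frames, H * D, S_h, S_w)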
@@ -186,7 +185,7 @@ def attention_percentile_mask_headwise(self, attn_map: torch.Tensor, r: float) -
         if k >= n:
             return torch.zeros_like(attn_map, dtype=torch.bool)

-        # 每个 head 独立计算阈值
+        # Calculate threshold for each head independently
         thresholds = torch.kthvalue(flat, k, dim=1).values  # [H]
         mask = attn_map >= thresholds[:, None, None]  # broadcasting

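For context (not part of the patch): torch.kthvalue gives every head its own threshold, the k-th smallest entry of that head's flattened map, and the comparison keeps entries at or above it. A toy sketch, assuming attn_map is [H, S_q, S_k] and assuming k is derived from the ratio r as below (the actual formula for k sits outside this hunk):

import torch

# Invented shapes: H heads, an S_q x S_k pooled attention map per head.
H, S_q, S_k = 2, 4, 4
attn_map = torch.rand(H, S_q, S_k)

r = 0.75                                       # assumed ratio used to pick k
flat = attn_map.reshape(H, -1)                 # [H, n]
n = flat.shape[1]
k = max(int(n * r), 1)                         # assumption; not the patch's formula

# k-th smallest value per head becomes that head's threshold.
thresholds = torch.kthvalue(flat, k, dim=1).values       # [H]
mask = attn_map >= thresholds[:, None, None]             # [H, S_q, S_k]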
@@ -221,14 +220,17 @@ def apply(
         block_size = frame_h * frame_w
         num_frames = seqlen // block_size

+        pool_h, pool_w = (8, 16) if frame_h < frame_w else (16, 8)
+
         self.prepare_reorg_idx_and_bucket_offset(
             seqlen=seqlen,
             frame_h=frame_h,
             frame_w=frame_w,
+            pool_h=pool_h,
+            pool_w=pool_w,
+            device=q.device,
         )

-        pool_h, pool_w = (8, 16) if frame_h < frame_w else (16, 8)
-
         attn = self.sample_qk_attention_2d(
             q,
             k,
@@ -258,7 +260,7 @@ def apply(

         q_ranges = torch.stack([q_start, q_end], dim=1).to(dtype=torch.int32)
         k_ranges = torch.stack([k_start, k_end], dim=1).to(dtype=torch.int32)
-        attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
+        attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device=q.device)

         reorg_idx = self.reorg_idx_dict[(seqlen, frame_h, frame_w)]
         q = q[reorg_idx]
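For context (not part of the patch): after gathering with reorg_idx, each bucket's tokens occupy one contiguous slice of the sequence, so a (start, end) pair per bucket is enough to describe the q/k ranges handed to the ranged attention kernel, and attn_type_map simply tags every pair with type 0. A toy sketch under that assumption; the bucket layout below is invented, and how q_start/q_end are actually derived is not shown in this hunk:

import torch

# Invented layout: 6 tokens split into buckets {0, 3, 5} and {1, 2, 4}.
reorg_idx = torch.tensor([0, 3, 5, 1, 2, 4])
bucket_offsets = torch.tensor([0, 3, 6], dtype=torch.int32)

q = torch.randn(6, 2, 4)                       # [seqlen, heads, head_dim]
q_reorg = q[reorg_idx]                         # bucket members are now contiguous

q_start, q_end = bucket_offsets[:-1], bucket_offsets[1:]
q_ranges = torch.stack([q_start, q_end], dim=1)           # [[0, 3], [3, 6]]

# One attention type (0) per range pair, allocated on q's device.
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device=q.device)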