Skip to content

Commit 63348f1

Browse files
committed
lint
1 parent 22196d0 commit 63348f1

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

lightx2v/common/ops/attn/draft_attn.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def build_grid_gather_index_and_bucket_fast(H, W, pool_h, pool_w, seqlen):
4646
Gh = (H + pool_h - 1) // pool_h
4747
Gw = (W + pool_w - 1) // pool_w
4848

49-
# 单帧
49+
# Single frame
5050
gather_single = []
5151
bucket_sizes_single = []
5252

@@ -90,23 +90,22 @@ def build_grid_gather_index_and_bucket_fast(H, W, pool_h, pool_w, seqlen):
9090

9191
@classmethod
9292
@torch.compiler.disable
93-
def prepare_reorg_idx_and_bucket_offset(cls, seqlen, frame_h, frame_w):
93+
def prepare_reorg_idx_and_bucket_offset(cls, seqlen, frame_h, frame_w, pool_h, pool_w, device):
9494
if (seqlen, frame_h, frame_w) in cls.reorg_idx_dict:
9595
return
96-
pool_h, pool_w = (8, 16) if frame_h < frame_w else (16, 8)
9796
reorg_idx, bucket_sizes, bucket_offsets = cls.build_grid_gather_index_and_bucket_fast(
9897
H=frame_h,
9998
W=frame_w,
10099
pool_h=pool_h,
101100
pool_w=pool_w,
102101
seqlen=seqlen,
103102
)
104-
reorg_idx = torch.tensor(reorg_idx, dtype=torch.long, device="cuda")
103+
reorg_idx = torch.tensor(reorg_idx, dtype=torch.long, device=device)
105104
restore_idx = torch.empty_like(reorg_idx)
106-
restore_idx[reorg_idx] = torch.arange(reorg_idx.numel(), device=reorg_idx.device)
105+
restore_idx[reorg_idx] = torch.arange(reorg_idx.numel(), device=device)
107106
cls.reorg_idx_dict[(seqlen, frame_h, frame_w)] = reorg_idx
108107
cls.restore_idx_dict[(seqlen, frame_h, frame_w)] = restore_idx
109-
cls.bucket_offsets_dict[(seqlen, frame_h, frame_w)] = torch.tensor(bucket_offsets, dtype=torch.int32, device="cuda")
108+
cls.bucket_offsets_dict[(seqlen, frame_h, frame_w)] = torch.tensor(bucket_offsets, dtype=torch.int32, device=device)
110109
logger.info(f"DraftAttnWeight: reorg_idx len: {len(reorg_idx)}")
111110
logger.info(f"DraftAttnWeight: bucket_sizes: {bucket_sizes}")
112111
logger.info(f"DraftAttnWeight: bucket_offsets: {bucket_offsets}")
@@ -136,7 +135,7 @@ def sample_qk_attention_2d(
136135
q_vid = q_vid.permute(0, 3, 4, 1, 2).reshape(num_frames, H * D, frame_h, frame_w)
137136
k_vid = k_vid.permute(0, 3, 4, 1, 2).reshape(num_frames, H * D, frame_h, frame_w)
138137

139-
# 3) 2D max‐pool each frame (ceil_mode ensures we cover the edges):
138+
# 3) 2D avg‐pool each frame (ceil_mode ensures we cover the edges):
140139
# → [num_frames, H*D, S_h, S_w]
141140
q_pooled = F.avg_pool2d(q_vid, kernel_size=(pool_h, pool_w), stride=(pool_h, pool_w), ceil_mode=True)
142141
k_pooled = F.avg_pool2d(k_vid, kernel_size=(pool_h, pool_w), stride=(pool_h, pool_w), ceil_mode=True)
@@ -186,7 +185,7 @@ def attention_percentile_mask_headwise(self, attn_map: torch.Tensor, r: float) -
186185
if k >= n:
187186
return torch.zeros_like(attn_map, dtype=torch.bool)
188187

189-
# 每个 head 独立计算阈值
188+
# Calculate threshold for each head independently
190189
thresholds = torch.kthvalue(flat, k, dim=1).values # [H]
191190
mask = attn_map >= thresholds[:, None, None] # broadcasting
192191

@@ -221,14 +220,17 @@ def apply(
221220
block_size = frame_h * frame_w
222221
num_frames = seqlen // block_size
223222

223+
pool_h, pool_w = (8, 16) if frame_h < frame_w else (16, 8)
224+
224225
self.prepare_reorg_idx_and_bucket_offset(
225226
seqlen=seqlen,
226227
frame_h=frame_h,
227228
frame_w=frame_w,
229+
pool_h=pool_h,
230+
pool_w=pool_w,
231+
device=q.device,
228232
)
229233

230-
pool_h, pool_w = (8, 16) if frame_h < frame_w else (16, 8)
231-
232234
attn = self.sample_qk_attention_2d(
233235
q,
234236
k,
@@ -258,7 +260,7 @@ def apply(
258260

259261
q_ranges = torch.stack([q_start, q_end], dim=1).to(dtype=torch.int32)
260262
k_ranges = torch.stack([k_start, k_end], dim=1).to(dtype=torch.int32)
261-
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
263+
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device=q.device)
262264

263265
reorg_idx = self.reorg_idx_dict[(seqlen, frame_h, frame_w)]
264266
q = q[reorg_idx]

0 commit comments

Comments
 (0)