Skip to content

Commit 1902799

Browse files
author
sangchengmeng
committed
1210
1 parent f6c5d64 commit 1902799

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
5+
from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import add_deepstack_embs
6+
7+
8+
def test_deepstack_same_image_twice():
    """Check that a repeated image receives its deepstack rows at every occurrence.

    The token sequence contains the image token run ``[100, 101, 102]`` twice.
    Starting from an all-zero embedding buffer, ``add_deepstack_embs`` is called
    once; the resulting per-position change is then compared against the
    deepstack rows for both runs, and the three non-image positions are
    required to stay at zero.

    All tensors are created on ``"cuda"``, so a CUDA device is required.
    """
    dev = "cuda"
    hidden_size = 4
    token_len = 3  # number of tokens per image

    # 1. Token ids holding the same image token range [100, 101, 102] twice
    #    (positions 1-3 and 5-7).
    input_ids = torch.tensor([1, 100, 101, 102, 2, 100, 101, 102, 3], device=dev, dtype=torch.long)
    seq_len = input_ids.shape[0]

    # 2. Zero-initialised embeddings, so whatever the kernel adds is directly visible.
    input_embeddings = torch.zeros(seq_len, hidden_size, device=dev, dtype=torch.float32)

    # 3. Deepstack embeddings for this layer. There is a single image, so the
    #    shape is [token_len, hidden_size]; row i is filled with the value i + 1.
    deepstack_rows = [
        [1.0, 1.0, 1.0, 1.0],  # for token_id = 100
        [2.0, 2.0, 2.0, 2.0],  # for token_id = 101
        [3.0, 3.0, 3.0, 3.0],  # for token_id = 102
    ]
    deepstack_embs = torch.tensor(deepstack_rows, device=dev, dtype=torch.float32)

    # 4. Image index metadata (same semantics as for multimodal_emb).
    # Only one image handle, whose token ids start at 100.
    img_start_token_ids = torch.tensor([100], device=dev, dtype=torch.long)
    img_token_lens = torch.tensor([token_len], device=dev, dtype=torch.long)
    # This image's rows start at row 0 of deepstack_embs.
    img_start_locs = torch.tensor([0], device=dev, dtype=torch.long)

    # 5. Snapshot of the embeddings before the call, used to compute the change.
    snapshot = input_embeddings.clone()

    # 6. Invoke the Triton operator; it writes into ``out``.
    add_deepstack_embs(
        out=input_embeddings,
        input_ids=input_ids,
        deepstack_embs=deepstack_embs,
        img_token_lens=img_token_lens,
        img_start_token_ids=img_start_token_ids,
        img_start_locs=img_start_locs,
    )

    # 7. Inspect the change on the two spans of the repeated image.
    delta = input_embeddings - snapshot
    first_span = delta[1:4]  # first occurrence: positions 1, 2, 3
    second_span = delta[5:8]  # second occurrence: positions 5, 6, 7

    print("input_ids:", input_ids)
    print("delta:\n", delta)
    print("first image span delta:\n", first_span)
    print("second image span delta:\n", second_span)

    # 8. Both spans must match the deepstack rows ([3, 4]).
    expected = deepstack_embs
    assert torch.allclose(first_span, expected), "first image span does not match expected deepstack"
    assert torch.allclose(second_span, expected), "second image span does not match expected deepstack"

    # Every non-image position must still be zero.
    for text_pos in (0, 4, 8):
        assert torch.all(delta[text_pos] == 0)

    print("OK: same image appears twice, both spans get deepstack added correctly.")
79+
80+
81+
# Allow running this check directly as a script, without a test runner.
if __name__ == "__main__":
    test_deepstack_same_image_twice()

0 commit comments

Comments
 (0)