|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo |
| 5 | +from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import add_deepstack_embs |
| 6 | + |
| 7 | + |
def test_deepstack_same_image_twice():
    """Exercise add_deepstack_embs when one image shows up twice in the prompt.

    The prompt holds the image token ids 100..102 at positions 1-3 and again at
    positions 5-7, while only a single image entry (start id 100, length 3,
    start row 0) is registered. Starting from an all-zero embedding buffer, the
    test expects the call to update ``out`` in place so that both spans equal
    the three deepstack rows, and the remaining positions 0, 4 and 8 stay zero.

    Requires a CUDA device, since every tensor is created on "cuda".
    """
    device = "cuda"
    hidden_size = 4
    tokens_per_image = 3

    # Token sequence in which the same image id range [100, 101, 102] occurs twice.
    input_ids = torch.tensor(
        [1, 100, 101, 102, 2, 100, 101, 102, 3],
        device=device,
        dtype=torch.long,
    )
    num_tokens = input_ids.shape[0]

    # Zero-initialised buffer, so whatever the operator adds is directly visible.
    out_buf = torch.zeros(num_tokens, hidden_size, device=device, dtype=torch.float32)

    # Deepstack features for this layer. There is only one image, hence the
    # shape [tokens_per_image, hidden_size]; each row is a distinct constant.
    image_rows = [
        [1.0, 1.0, 1.0, 1.0],  # meant for token id 100
        [2.0, 2.0, 2.0, 2.0],  # meant for token id 101
        [3.0, 3.0, 3.0, 3.0],  # meant for token id 102
    ]
    deepstack_embs = torch.tensor(image_rows, device=device, dtype=torch.float32)

    # Per-image index metadata (the original author notes these carry the same
    # meaning as in multimodal_emb): one image handle whose ids start at 100,
    # spanning tokens_per_image tokens, with its rows starting at row 0 of
    # deepstack_embs.
    index_opts = dict(device=device, dtype=torch.long)
    img_start_token_ids = torch.tensor([100], **index_opts)
    img_token_lens = torch.tensor([tokens_per_image], **index_opts)
    img_start_locs = torch.tensor([0], **index_opts)

    # Snapshot taken before the call so the added amount can be computed.
    snapshot = out_buf.clone()

    # Run the Triton operator under test.
    add_deepstack_embs(
        out=out_buf,
        input_ids=input_ids,
        deepstack_embs=deepstack_embs,
        img_token_lens=img_token_lens,
        img_start_token_ids=img_start_token_ids,
        img_start_locs=img_start_locs,
    )

    # What the operator added, overall and on each of the two image spans.
    added = out_buf - snapshot
    first_span = added[1:4]  # first occurrence: positions 1, 2, 3
    second_span = added[5:8]  # second occurrence: positions 5, 6, 7

    print("input_ids:", input_ids)
    print("delta:\n", added)
    print("first image span delta:\n", first_span)
    print("second image span delta:\n", second_span)

    # Both occurrences must have received exactly the deepstack rows ([3, 4]).
    assert torch.allclose(first_span, deepstack_embs), "first image span does not match expected deepstack"
    assert torch.allclose(second_span, deepstack_embs), "second image span does not match expected deepstack"

    # Every non-image position must still be zero.
    for text_pos in (0, 4, 8):
        assert torch.all(added[text_pos] == 0)

    print("OK: same image appears twice, both spans get deepstack added correctly.")
| 79 | + |
| 80 | + |
# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_deepstack_same_image_twice()
0 commit comments