diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index aeffd3a67..65ec4354b 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -82,13 +82,20 @@ def is_leaf(self): return len(self.children) == 0 -def match(key, seq): - i = 0 - for k, w in zip(key, seq): - if k != w: - break - i += 1 - return i +def match(t1: torch.Tensor, t2: torch.Tensor) -> int: + # Ensure same shape for comparison: flatten and get min length + t1_flat = t1.flatten() + t2_flat = t2.flatten() + min_len = min(t1_flat.size(0), t2_flat.size(0)) + + # Compare elements and find first mismatch + diff = t1_flat[:min_len] != t2_flat[:min_len] + mismatch_indices = torch.nonzero(diff) + + if mismatch_indices.numel() == 0: + return min_len # All matched up to min_len + else: + return mismatch_indices[0].item() class RadixCache: