Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions lightllm/server/router/dynamic_prompt/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,20 @@ def is_leaf(self):
return len(self.children) == 0


def match(key, seq):
    """Return the length of the common prefix shared by *key* and *seq*.

    Walks both sequences in lockstep and stops at the first position
    where the elements disagree (or when the shorter sequence ends).
    """
    matched = 0
    for a, b in zip(key, seq):
        if a == b:
            matched += 1
        else:
            break
    return matched
def match(t1: torch.Tensor, t2: torch.Tensor) -> int:
    """Return the length of the common element-wise prefix of two tensors.

    Both tensors are flattened and compared over the first ``min_len``
    elements (the length of the shorter one).

    Args:
        t1: First tensor; any shape, comparable dtype with ``t2``.
        t2: Second tensor.

    Returns:
        Index of the first mismatch, or ``min_len`` when every compared
        element is equal (including the case where either tensor is empty).
    """
    t1_flat = t1.flatten()
    t2_flat = t2.flatten()
    min_len = min(t1_flat.size(0), t2_flat.size(0))
    if min_len == 0:
        # Nothing to compare; the common prefix is trivially empty.
        return 0

    diff = t1_flat[:min_len] != t2_flat[:min_len]
    if not torch.any(diff):
        return min_len  # All compared elements matched.

    # argmax over the 0/1 mask yields the index of the FIRST mismatch
    # without materializing every mismatch index (unlike torch.nonzero,
    # which allocates one row per differing element).
    return int(torch.argmax(diff.byte()).item())
Comment on lines +85 to +98
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider using torch.argmax instead of torch.nonzero for better efficiency. torch.nonzero creates a tensor of all indices where the condition is true, which can be memory-intensive. torch.argmax finds the index of the first True value without allocating memory for all mismatch indices.

Suggested change
def match(t1: torch.Tensor, t2: torch.Tensor) -> int:
# Ensure same shape for comparison: flatten and get min length
t1_flat = t1.flatten()
t2_flat = t2.flatten()
min_len = min(t1_flat.size(0), t2_flat.size(0))
# Compare elements and find first mismatch
diff = t1_flat[:min_len] != t2_flat[:min_len]
mismatch_indices = torch.nonzero(diff)
if mismatch_indices.numel() == 0:
return min_len # All matched up to min_len
else:
return mismatch_indices[0].item()
def match(t1: torch.Tensor, t2: torch.Tensor) -> int:
t1_flat = t1.flatten()
t2_flat = t2.flatten()
min_len = min(t1_flat.size(0), t2_flat.size(0))
if min_len == 0:
return 0
diff = t1_flat[:min_len] != t2_flat[:min_len]
if not torch.any(diff):
return min_len
return torch.argmax(diff.byte()).item()



class RadixCache:
Expand Down