Conversation
LGTM, you can fix the ruff issue and then we will merge :) I need to fix the CI so it runs the tests when you open a PR.
Pull Request Overview
This PR refactors the search functionality to accept queries as a vector of individual tensors instead of a single batched 3D tensor. This allows for more flexible handling of variable-length query sequences without requiring padding at the batch level.
Key changes:
- Changed query input format from a single 3D tensor to a vector of individual tensors
- Removed tensor batching/padding logic in favor of list slicing for query distribution
- Added a cleanup_embeddings helper function to normalize query embeddings
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| rust/search/search.rs | Updated search_many to accept Vec<Tensor> instead of a single tensor; removed tensor shape validation; removed unused imports bail and IndexOp |
| rust/lib.rs | Updated load_and_search to accept Vec<PyTensor> and convert each tensor individually to Kind::Half |
| python/fast_plaid/search/fast_plaid.py | Added cleanup_embeddings helper; replaced tensor operations (chunk, split, pad_sequence) with list slicing; updated type hints to list[torch.Tensor] |
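The PR description mentions a cleanup_embeddings helper but does not show its body. A hypothetical sketch of what such a helper might look like, assuming it L2-normalizes each variable-length query tensor and casts to half precision to match the Kind::Half conversion on the Rust side (the function body and the float16 cast are assumptions, not the PR's actual code):

```python
import torch

def cleanup_embeddings(embeddings: list[torch.Tensor]) -> list[torch.Tensor]:
    # Hypothetical sketch: L2-normalize each query along the embedding
    # dimension, then cast to float16 to match the Rust-side Kind::Half.
    return [
        torch.nn.functional.normalize(e, p=2, dim=-1).to(torch.float16)
        for e in embeddings
    ]

# Variable-length queries: no batch-level padding required.
queries = [torch.randn(5, 128), torch.randn(3, 128)]
cleaned = cleanup_embeddings(queries)
```

Because the queries stay in a plain list, each tensor keeps its own token count, which is the point of the refactor.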
Comments suppressed due to low confidence (1)
rust/search/search.rs:157
The documentation is outdated. The parameter is now &Vec&lt;Tensor&gt; where each tensor represents an individual query, not a single 3D tensor.
/// * `queries` - A 3D tensor of query embeddings with shape `[num_queries, tokens_per_query, dim]`.
```python
            dim=0,
        )
        queries_embeddings_splits = [
            queries_embeddings[i:i + num_cpus] for i in range(0, num_queries, num_cpus)
```
The list comprehension creates chunks of size num_cpus starting at increments of num_cpus, which is incorrect. The step should match the chunk size to avoid overlap. Use `queries_embeddings[i*num_cpus:(i+1)*num_cpus] for i in range((num_queries + num_cpus - 1) // num_cpus)` or similar logic to properly partition the list.

Suggested change:
```python
# current
queries_embeddings[i:i + num_cpus] for i in range(0, num_queries, num_cpus)
# Copilot's suggestion
queries_embeddings[i*num_cpus:(i+1)*num_cpus] for i in range((num_queries + num_cpus - 1) // num_cpus)
```
The list comprehension looks correct to me: slicing with a step equal to the chunk size produces non-overlapping chunks. I don't know why Copilot thinks it's wrong.
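A quick demonstration that the pattern Copilot flagged does partition a list into non-overlapping chunks (plain integers stand in for the query tensors):

```python
# Slicing with step == chunk size partitions the list without overlap.
queries_embeddings = list(range(10))  # stand-in for a list of query tensors
num_cpus = 3
num_queries = len(queries_embeddings)

queries_embeddings_splits = [
    queries_embeddings[i:i + num_cpus]
    for i in range(0, num_queries, num_cpus)
]

# Flattening the chunks reproduces the original list exactly:
# every element appears once, so there is no overlap and no omission.
flattened = [x for chunk in queries_embeddings_splits for x in chunk]
assert flattened == queries_embeddings
print(queries_embeddings_splits)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```

Copilot's suggested `(num_queries + num_cpus - 1) // num_cpus` ceiling-division variant yields the same chunks; the original is simply the more idiomatic spelling.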
```python
        queries_embeddings_splits = [
            queries_embeddings[i:i + len(self.devices)] for i in range(0, num_queries, len(self.devices))
        ]
```
The list comprehension creates chunks of size len(self.devices) starting at increments of len(self.devices), which is incorrect. The step should match the chunk size to avoid overlap. Use proper chunking logic to partition the list without overlap.
Same here. The step equals the chunk size, so the slices don't overlap; I'm unsure what Copilot thinks is wrong.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This PR allows variable-length queries for searching.