fix: fix wrong param in SentenceChunker #370
base: dev
Conversation
Please pull the MemTensor:dev branch and resolve the conflicts, thank you. @Linorman

OK, I have already resolved the conflicts and merged them into my main branch.
Pull Request Overview
This PR updates the codebase to support both legacy and new API versions of the DynamicCache class from the transformers library. The new API uses a layers attribute with .keys and .values properties, while the legacy API uses key_cache and value_cache list attributes.
- Removes version-based branching (previously using `packaging.version` checks) in favor of runtime attribute detection
- Updates cache concatenation logic to handle both APIs through `hasattr()` checks
- Fixes parameter naming in Chonkie `SentenceChunker` initialization
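The runtime attribute detection described above can be sketched as follows. This is a minimal illustration with stand-in objects, not the PR's actual code; `uses_new_cache_api` is a hypothetical helper name, and the attribute names (`layers` vs. `key_cache`/`value_cache`) are the two `DynamicCache` API shapes the overview describes.

```python
# Sketch of runtime attribute detection for the two DynamicCache API shapes.
# Stand-in objects are used so no transformers/torch install is required.
from types import SimpleNamespace

def uses_new_cache_api(cache) -> bool:
    """True when the cache exposes the new `.layers` API,
    False when it exposes the legacy `.key_cache`/`.value_cache` lists."""
    return hasattr(cache, "layers")

# Minimal mocks of both API shapes:
new_style = SimpleNamespace(layers=[SimpleNamespace(keys=None, values=None)])
old_style = SimpleNamespace(key_cache=[], value_cache=[])

assert uses_new_cache_api(new_style)
assert not uses_new_cache_api(old_style)
```

Detecting the attribute at runtime avoids pinning the check to a specific `transformers` version string, which is the motivation for dropping the `packaging.version` branching.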
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/memos/memories/activation/kv.py | Replaced version-based API detection with runtime attribute checking; removed unused imports; updated _concat_caches to mutate first cache in-place for new API |
| tests/memories/activation/test_kv.py | Added compatibility layer in test helper make_filled_cache() and assertions to support both old and new DynamicCache APIs |
| src/memos/mem_os/utils/format_utils.py | Updated serialization functions to detect and handle both DynamicCache API versions when extracting layer counts, device info, dtype, and tensor shapes |
| src/memos/chunkers/sentence_chunker.py | Changed parameter name from tokenizer_or_token_counter to tokenizer to match the expected API |
```diff
  base = caches[0]
  for layer in range(num_layers):
      # gather all K and V for this layer
      keys = [c.layers[layer].keys for c in caches]
      vals = [c.layers[layer].values for c in caches]
      # single concat per layer
-     merged.layers[layer].keys = torch.cat(keys, dim=-2)
-     merged.layers[layer].values = torch.cat(vals, dim=-2)
+     base.layers[layer].keys = torch.cat(keys, dim=-2)
+     base.layers[layer].values = torch.cat(vals, dim=-2)
  return base
```
Copilot AI · Nov 4, 2025
Mutating the first cache object in-place (caches[0]) is problematic because it modifies the original cache that may still be referenced elsewhere. This could lead to unexpected side effects if the caller expects the original caches to remain unchanged. Consider creating a new DynamicCache() object and populating its layers similar to the legacy API path.
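The non-mutating merge Copilot suggests can be sketched like this. Python lists stand in for tensors and list concatenation stands in for `torch.cat(..., dim=-2)`, so the example runs without torch; `merge_caches` is a hypothetical name, and the real code would populate a fresh `DynamicCache()` instead of a `SimpleNamespace`.

```python
# Torch-free sketch of the suggestion: build a fresh container rather than
# mutating caches[0], so the caller's original caches stay unchanged.
from types import SimpleNamespace

def merge_caches(caches):
    merged = SimpleNamespace(layers=[])  # fresh object; inputs untouched
    for layer in range(len(caches[0].layers)):
        keys = [c.layers[layer].keys for c in caches]
        vals = [c.layers[layer].values for c in caches]
        merged.layers.append(SimpleNamespace(
            keys=sum(keys, []),    # stand-in for torch.cat(keys, dim=-2)
            values=sum(vals, []),  # stand-in for torch.cat(vals, dim=-2)
        ))
    return merged

a = SimpleNamespace(layers=[SimpleNamespace(keys=[1], values=[2])])
b = SimpleNamespace(layers=[SimpleNamespace(keys=[3], values=[4])])
m = merge_caches([a, b])
assert m.layers[0].keys == [1, 3]
assert a.layers[0].keys == [1]  # input cache not mutated
```

The trade-off is an extra allocation per merge in exchange for referential safety: callers holding references to the input caches never observe them change underneath.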
```python
t = getattr(lyr, "keys", None)
if t is None:
    t = getattr(lyr, "values", None)
if t is not None:
    device_str = str(t.device)
    dtype_str = str(t.dtype)
```
Copilot AI · Nov 4, 2025
[nitpick] The variable name t is unclear and could be improved for readability. Consider renaming it to tensor or sample_tensor to better convey its purpose as a tensor used to determine device and dtype information.
```diff
- t = getattr(lyr, "keys", None)
- if t is None:
-     t = getattr(lyr, "values", None)
- if t is not None:
-     device_str = str(t.device)
-     dtype_str = str(t.dtype)
+ tensor = getattr(lyr, "keys", None)
+ if tensor is None:
+     tensor = getattr(lyr, "values", None)
+ if tensor is not None:
+     device_str = str(tensor.device)
+     dtype_str = str(tensor.dtype)
```
Description
Summary: Fix the parameter name passed to SentenceChunker to avoid a runtime error.
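Why a wrong keyword name is a runtime bug rather than a cosmetic one: Python raises `TypeError` at construction time for an unexpected keyword argument. A minimal illustration with a hypothetical stub class (not chonkie's real `SentenceChunker`, whose accepted parameters depend on the installed version):

```python
# Hypothetical stub showing the failure mode of a mismatched keyword name.
class SentenceChunkerStub:
    def __init__(self, tokenizer=None, chunk_size=512):
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size

ok = SentenceChunkerStub(tokenizer="gpt2")  # matches the signature

raised = False
try:
    SentenceChunkerStub(tokenizer_or_token_counter="gpt2")
except TypeError:
    raised = True  # unexpected keyword argument fails immediately
assert raised
```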
Fix: #(issue)
Docs Issue/PR: (docs-issue-or-pr-link)
Reviewer: @(reviewer)
Checklist: