Continue implementation of experimental Redis vector search #2430
Changes from 6 commits
```diff
@@ -3,7 +3,7 @@
 from chatterbot.storage import StorageAdapter
 from chatterbot.logic import LogicAdapter
 from chatterbot.search import TextSearch, IndexedTextSearch
-from chatterbot.tagging import PosLemmaTagger
+from chatterbot.tagging import PosLemmaTagger, NoOpTagger
 from chatterbot.conversation import Statement
 from chatterbot import languages
 from chatterbot import utils
```
```diff
@@ -74,41 +74,62 @@ def __init__(self, name, stream=False, **kwargs):
 
         tagger_language = kwargs.get('tagger_language', languages.ENG)
 
-        try:
-            Tagger = kwargs.get('tagger', PosLemmaTagger)
-
-            # Allow instances to be provided for performance optimization
-            # (Example: a pre-loaded model in a tagger when unit testing)
-            if not isinstance(Tagger, type):
-                self.tagger = Tagger
-            else:
-                self.tagger = Tagger(language=tagger_language)
-        except IOError as io_error:
-            # Return a more helpful error message if possible
-            if "Can't find model" in str(io_error):
-                model_name = utils.get_model_for_language(tagger_language)
-                if hasattr(tagger_language, 'ENGLISH_NAME'):
-                    language_name = tagger_language.ENGLISH_NAME
-                else:
-                    language_name = tagger_language
-                raise self.ChatBotException(
-                    'Setup error:\n'
-                    f'The Spacy model for "{language_name}" language is missing.\n'
-                    'Please install the model using the command:\n\n'
-                    f'python -m spacy download {model_name}\n\n'
-                    'See https://spacy.io/usage/models for more information about available models.'
-                ) from io_error
-            else:
-                raise io_error
+        # Check if storage adapter has a preferred tagger
+        PreferredTagger = self.storage.get_preferred_tagger()
+
+        if PreferredTagger is not None:
+            # Storage adapter specifies its own tagger
+            self.tagger = PreferredTagger(language=tagger_language)
+        else:
+            # Use default or user-specified tagger
+            try:
+                Tagger = kwargs.get('tagger', PosLemmaTagger)
+
+                # Allow instances to be provided for performance optimization
+                # (Example: a pre-loaded model in a tagger when unit testing)
+                if not isinstance(Tagger, type):
+                    self.tagger = Tagger
+                else:
+                    self.tagger = Tagger(language=tagger_language)
+            except IOError as io_error:
+                # Return a more helpful error message if possible
+                if "Can't find model" in str(io_error):
+                    model_name = utils.get_model_for_language(tagger_language)
+                    if hasattr(tagger_language, 'ENGLISH_NAME'):
+                        language_name = tagger_language.ENGLISH_NAME
+                    else:
+                        language_name = tagger_language
+                    raise self.ChatBotException(
+                        'Setup error:\n'
+                        f'The Spacy model for "{language_name}" language is missing.\n'
+                        'Please install the model using the command:\n\n'
+                        f'python -m spacy download {model_name}\n\n'
+                        'See https://spacy.io/usage/models for more information about available models.'
+                    ) from io_error
+                else:
+                    raise io_error
+
+        # Initialize search algorithms
+        from chatterbot.search import SemanticVectorSearch
 
         primary_search_algorithm = IndexedTextSearch(self, **kwargs)
         text_search_algorithm = TextSearch(self, **kwargs)
+        semantic_vector_search_algorithm = SemanticVectorSearch(self, **kwargs)
 
         self.search_algorithms = {
             primary_search_algorithm.name: primary_search_algorithm,
-            text_search_algorithm.name: text_search_algorithm
+            text_search_algorithm.name: text_search_algorithm,
+            semantic_vector_search_algorithm.name: semantic_vector_search_algorithm
         }
 
+        # Check if storage adapter has a preferred search algorithm
+        preferred_search_algorithm = self.storage.get_preferred_search_algorithm()
+        if preferred_search_algorithm and preferred_search_algorithm in self.search_algorithms:
+            # Set as default for logic adapters that don't specify their own search algorithm
+            # This ensures BestMatch and other adapters use the optimal search method
+            self.logger.info(f'Storage adapter prefers search algorithm: {preferred_search_algorithm}')
+            kwargs.setdefault('search_algorithm_name', preferred_search_algorithm)
+
         for adapter in logic_adapters:
             utils.validate_adapter_class(adapter, LogicAdapter)
             logic_adapter = utils.initialize_class(adapter, self, **kwargs)
```
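The constructor now asks the storage adapter for two preferences: `get_preferred_tagger()` may return a tagger class to use instead of the default `PosLemmaTagger`, and `get_preferred_search_algorithm()` may name a registered search algorithm to set as the default `search_algorithm_name` for logic adapters. Below is a minimal sketch of how an adapter might opt in; it assumes the base `StorageAdapter` returns `None` from both hooks by default, that `NoOpTagger` accepts a `language=` argument like other ChatterBot taggers, and that the semantic search registers under a name such as `'semantic_vector_search'`. The class name and that string are illustrative, not taken from this diff.

```python
from chatterbot.storage import StorageAdapter
from chatterbot.tagging import NoOpTagger


class VectorPreferringStorageAdapter(StorageAdapter):
    """Hypothetical adapter that opts in to both new hooks."""

    def get_preferred_tagger(self):
        # Return a tagger class (not an instance); ChatBot instantiates it
        # as PreferredTagger(language=tagger_language).
        return NoOpTagger

    def get_preferred_search_algorithm(self):
        # Return the name of an algorithm registered in ChatBot.search_algorithms;
        # names that are not registered are ignored by the constructor.
        return 'semantic_vector_search'  # assumed name, not confirmed by this diff
```

Because the preferred name is applied with `kwargs.setdefault('search_algorithm_name', ...)`, a user who passes `search_algorithm_name` explicitly still takes precedence over the storage adapter's preference.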
```diff
@@ -191,15 +212,22 @@ def get_response(self, statement: Union[Statement, str, dict] = None, **kwargs)
             input_statement.in_response_to = previous_statement.text
 
         # Make sure the input statement has its search text saved
-
-        if not input_statement.search_text:
-            _search_text = self.tagger.get_text_index_string(input_statement.text)
-            input_statement.search_text = _search_text
-
-        if not input_statement.search_in_response_to and input_statement.in_response_to:
-            input_statement.search_in_response_to = self.tagger.get_text_index_string(
-                input_statement.in_response_to
-            )
+        if isinstance(self.tagger, NoOpTagger):
+            # NoOpTagger returns text unchanged, so we can skip the tagging call
+            if not input_statement.search_text:
+                input_statement.search_text = input_statement.text
+            if not input_statement.search_in_response_to and input_statement.in_response_to:
+                input_statement.search_in_response_to = input_statement.in_response_to
+        else:
+            # Use tagger for POS-lemma indexing or other transformations
+            if not input_statement.search_text:
+                _search_text = self.tagger.get_text_index_string(input_statement.text)
+                input_statement.search_text = _search_text
+
+            if not input_statement.search_in_response_to and input_statement.in_response_to:
+                input_statement.search_in_response_to = self.tagger.get_text_index_string(
+                    input_statement.in_response_to
+                )
 
         response = self.generate_response(
             input_statement,
```
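The new branch in `get_response` only checks the tagger's type: it relies on `NoOpTagger` leaving text unchanged (per the comment in the diff), so assigning the raw text directly is equivalent to calling `get_text_index_string` while avoiding the call overhead. A stand-in sketch of that assumed contract follows; the real `NoOpTagger` lives in `chatterbot.tagging` and its implementation is not shown in this diff.

```python
class IllustrativeNoOpTagger:
    """Stand-in with the behavior the short-circuit assumes."""

    def __init__(self, language=None):
        self.language = language

    def get_text_index_string(self, text):
        # No POS tagging or lemmatization: the "index string" is the text
        # itself, so search_text == text and the tagger call can be skipped.
        return text


tagger = IllustrativeNoOpTagger()
assert tagger.get_text_index_string('Hello, how are you?') == 'Hello, how are you?'
```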
Review comment: [nitpick] There's a blank line after the import statement that creates inconsistent spacing. Consider removing this trailing blank line for cleaner code formatting.