@@ -11,7 +11,6 @@
 
 from langchain_core.callbacks import Callbacks
 from langchain_core.documents import Document, BaseDocumentCompressor
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 from models_provider.base_model_provider import MaxKBBaseModel
 
@@ -25,6 +24,7 @@ class LocalReranker(MaxKBBaseModel, BaseDocumentCompressor):
 
     def __init__(self, model_name, cache_dir=None, **model_kwargs):
         super().__init__()
+        from transformers import AutoModelForSequenceClassification, AutoTokenizer
         self.model = model_name
         self.cache_dir = cache_dir
         self.model_kwargs = model_kwargs
Contributor Author

The provided code checks out with no immediate errors or issues related to syntax, logic, or imports. Moving the transformers import from module level into __init__ defers loading AutoModelForSequenceClassification and AutoTokenizer until a LocalReranker is actually instantiated, and the import still happens before either class is used.

Optimization suggestion: make sure that paths used in file operations (such as reading data or the model cache directory) are resolved relative to a known location rather than the current working directory; otherwise, running the script from a different directory can raise FileNotFoundError. A small sketch of this follows below.
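
As a minimal illustration of that point (the DEFAULT_CACHE_DIR name and the "model_cache" directory are assumptions for this example, not part of the PR), a cache path can be anchored to the module file rather than the working directory:

from pathlib import Path

# Resolve the cache directory relative to this module file rather than the
# current working directory, so the path stays valid regardless of where the
# script is launched from. (DEFAULT_CACHE_DIR and "model_cache" are hypothetical.)
DEFAULT_CACHE_DIR = Path(__file__).resolve().parent / "model_cache"
DEFAULT_CACHE_DIR.mkdir(parents=True, exist_ok=True)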

Additionally, consider handling exceptions, for example when importing the library or loading the model, so that failures surface with a clear message instead of failing silently:

from typing import Optional

from langchain_core.documents import BaseDocumentCompressor

from models_provider.base_model_provider import MaxKBBaseModel


class LocalReranker(MaxKBBaseModel, BaseDocumentCompressor):
    __name__ = "LocalReranker"

    def __init__(self, model_name: str, cache_dir: Optional[str] = None, **model_kwargs):
        super().__init__()

        # Import lazily so the module can be loaded even when transformers is not
        # installed, and fail with a clear message when the dependency is missing.
        try:
            from transformers import AutoModelForSequenceClassification, AutoTokenizer
        except ImportError as e:
            raise ImportError(f"Failed to import the required transformers library: {e}") from e

        self.model_name = model_name
        self.cache_dir = cache_dir if cache_dir else "/path/to/default/cache/dir"
        self.model_kwargs = model_kwargs

        # Initialize or load the model here
This modification adds basic error checking for missing dependencies and specifies a default path for caching models. Adjust the path according to how you set up your environment.
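
To make the "Initialize or load the model here" step concrete, here is a minimal, hypothetical sketch (the load_reranker helper and its error handling are illustrations, not part of this PR) that loads the tokenizer and model through the same deferred import and reports load failures explicitly; from_pretrained typically raises OSError when the model id or cache path is invalid:

from typing import Optional


def load_reranker(model_name: str, cache_dir: Optional[str] = None, **model_kwargs):
    # Deferred import, mirroring the pattern used in LocalReranker.__init__.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, cache_dir=cache_dir, **model_kwargs
        )
    except OSError as e:
        # from_pretrained raises OSError for an unknown model id or a bad cache path.
        raise RuntimeError(f"Failed to load reranker model '{model_name}': {e}") from e

    model.eval()  # reranking is inference-only, so disable training-time behavior
    return model, tokenizer

LocalReranker.__init__ could then call something like model, tokenizer = load_reranker(self.model_name, self.cache_dir, **self.model_kwargs) and store the results on the instance.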
