22
33from confection import SimpleFrozenDict
44
5- from ...compat import ExtraError , ValidationError , has_langchain , langchain
5+ from ...compat import ExtraError , ValidationError , has_langchain , langchain_community
66from ...registry import registry
77
88try :
9- from langchain import llms # noqa: F401
9+ from langchain_community import llms # noqa: F401
1010except (ImportError , AttributeError ):
1111 llms = None
1212
@@ -18,16 +18,17 @@ def __init__(
1818 api : str ,
1919 config : Dict [Any , Any ],
2020 query : Callable [
21- ["langchain.llms.BaseLLM" , Iterable [Iterable [Any ]]], Iterable [Iterable [Any ]]
21+ ["langchain_community.llms.BaseLLM" , Iterable [Iterable [Any ]]],
22+ Iterable [Iterable [Any ]],
2223 ],
2324 context_length : Optional [int ],
2425 ):
2526 """Initializes model instance for integration APIs.
2627 name (str): Name of LangChain model to instantiate.
2728 api (str): Name of class/API.
2829 config (Dict[Any, Any]): Config passed on to LangChain model.
29- query (Callable[[langchain .llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
30- LLM prompts when supplied with the model instance.
30+ query (Callable[[langchain_community .llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
31+ executing LLM prompts when supplied with the model instance.
3132 context_length (Optional[int]): Context length for this model. Only necessary for sharding. If no context
3233 length provided, prompts can't be sharded.
3334 """
@@ -39,7 +40,7 @@ def __init__(
3940 @classmethod
4041 def _init_langchain_model (
4142 cls , name : str , api : str , config : Dict [Any , Any ]
42- ) -> "langchain .llms.BaseLLM" :
43+ ) -> "langchain_community .llms.BaseLLM" :
4344 """Initializes langchain model. langchain expects a range of different model ID argument names, depending on the
4445 model class. There doesn't seem to be a clean way to determine those from the outset, we'll fail our way through
4546 them.
@@ -73,12 +74,13 @@ def _init_langchain_model(
7374 raise err
7475
7576 @staticmethod
76- def get_type_to_cls_dict () -> Dict [str , Type ["langchain .llms.BaseLLM" ]]:
77- """Returns langchain .llms.type_to_cls_dict.
78- RETURNS (Dict[str, Type[langchain .llms.BaseLLM]]): langchain .llms.type_to_cls_dict.
77+ def get_type_to_cls_dict () -> Dict [str , Type ["langchain_community .llms.BaseLLM" ]]:
78+ """Returns langchain_community .llms.type_to_cls_dict.
79+ RETURNS (Dict[str, Type[langchain_community .llms.BaseLLM]]): langchain_community .llms.type_to_cls_dict.
7980 """
8081 return {
81- llm_id : getattr (langchain .llms , llm_id ) for llm_id in langchain .llms .__all__
82+ llm_id : getattr (langchain_community .llms , llm_id )
83+ for llm_id in langchain_community .llms .__all__
8284 }
8385
8486 def __call__ (self , prompts : Iterable [Iterable [Any ]]) -> Iterable [Iterable [Any ]]:
@@ -90,15 +92,16 @@ def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
9092
9193 @staticmethod
9294 def query_langchain (
93- model : "langchain .llms.BaseLLM" , prompts : Iterable [Iterable [Any ]]
95+ model : "langchain_community .llms.BaseLLM" , prompts : Iterable [Iterable [Any ]]
9496 ) -> Iterable [Iterable [Any ]]:
9597 """Query LangChain model naively.
96- model (langchain .llms.BaseLLM): LangChain model.
98+ model (langchain_community .llms.BaseLLM): LangChain model.
9799 prompts (Iterable[Iterable[Any]]): Prompts to execute.
98100 RETURNS (Iterable[Iterable[Any]]): LLM responses.
99101 """
100- assert callable (model )
101- return [[model (pr ) for pr in prompts_for_doc ] for prompts_for_doc in prompts ]
102+ return [
103+ [model .invoke (pr ) for pr in prompts_for_doc ] for prompts_for_doc in prompts
104+ ]
102105
103106 @staticmethod
104107 def _check_installation () -> None :
@@ -115,7 +118,7 @@ def langchain_model(
115118 name : str ,
116119 query : Optional [
117120 Callable [
118- ["langchain .llms.BaseLLM" , Iterable [Iterable [str ]]],
121+ ["langchain_community .llms.BaseLLM" , Iterable [Iterable [str ]]],
119122 Iterable [Iterable [str ]],
120123 ]
121124 ] = None ,
@@ -170,11 +173,12 @@ def register_models() -> None:
170173@registry .llm_queries ("spacy.CallLangChain.v1" )
171174def query_langchain () -> (
172175 Callable [
173- ["langchain.llms.BaseLLM" , Iterable [Iterable [Any ]]], Iterable [Iterable [Any ]]
176+ ["langchain_community.llms.BaseLLM" , Iterable [Iterable [Any ]]],
177+ Iterable [Iterable [Any ]],
174178 ]
175179):
176180 """Returns query Callable for LangChain.
177- RETURNS (Callable[["langchain .llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
178- simple prompts on the specified LangChain model.
181+ RETURNS (Callable[["langchain_community .llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
182+ executing simple prompts on the specified LangChain model.
179183 """
180184 return LangChain .query_langchain
0 commit comments