Commit 789ee8b

Adjust LangChain usage in line with langchain >= 0.1 (#433)
* Update langchain pin. Adjust langchain model invocation in line with >= 0.1 changes.
* Import from langchain_community instead of from langchain.
* Ignore langchain_community deprecation warning.
1 parent 2e88594 commit 789ee8b

5 files changed: 28 additions, 21 deletions
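
Before the file-by-file diff, a minimal sketch of the calling convention the change adopts: in langchain >= 0.1 the provider-backed LLM classes live in the separate langchain_community package, and calling a model object directly is deprecated in favour of the Runnable-style invoke(). This snippet is not part of the commit; it uses FakeListLLM so no API key or network access is needed.

from langchain_community.llms import FakeListLLM

# A stand-in LLM that just replays canned responses.
llm = FakeListLLM(responses=["PERSON: Alice"])

# Pre-0.1 style (now deprecated): response = llm("Extract entities: Alice visited Berlin.")
# >= 0.1 style: use invoke() from the Runnable interface.
response = llm.invoke("Extract entities: Alice visited Berlin.")
print(response)  # "PERSON: Alice"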

pyproject.toml

Lines changed: 2 additions & 1 deletion

@@ -26,7 +26,8 @@ filterwarnings = [
     "ignore:^.*`__get_validators__` is deprecated.*",
     "ignore:^.*The `construct` method is deprecated.*",
     "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
-    "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*"
+    "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*",
+    "ignore:^.*was deprecated in langchain-community.*"
 ]
 markers = [
     "external: interacts with a (potentially cost-incurring) third-party API",

requirements-dev.txt

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ mypy>=0.990,<1.1.0; platform_machine != "aarch64" and python_version >= "3.7"
 black==22.3.0
 types-requests==2.28.11.16
 # Prompting libraries needed for testing
-langchain==0.0.331; python_version>="3.9"
+langchain>=0.1,<0.2; python_version>="3.9"
 # Workaround for LangChain bug: pin OpenAI version. To be removed after LangChain has been fixed - see
 # https://github.com/langchain-ai/langchain/issues/12967.
 openai>=0.27,<=0.28.1; python_version>="3.9"

setup.cfg

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ spacy_misc =
 
 [options.extras_require]
 langchain =
-    langchain==0.0.335
+    langchain>=0.1,<0.2
 transformers =
     torch>=1.13.1,<2.0
     transformers>=4.28.1,<5.0

spacy_llm/compat.py

Lines changed: 2 additions & 0 deletions

@@ -19,10 +19,12 @@
 
 try:
     import langchain
+    import langchain_community
 
     has_langchain = True
 except (ImportError, AttributeError):
     langchain = None
+    langchain_community = None
     has_langchain = False
 
 try:
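
compat.py keeps both handles optional: langchain_community is a module object when the dependency is installed and None otherwise, with has_langchain as the flag to check. A small illustrative sketch of how downstream code can rely on that flag; the helper name and error message are hypothetical, not taken from the repo.

from spacy_llm.compat import has_langchain

def _ensure_langchain() -> None:
    # Hypothetical guard: fail fast with an actionable message instead of an
    # AttributeError on a None module later on.
    if not has_langchain:
        raise ValueError(
            "LangChain support requires the optional dependency, e.g. "
            "pip install 'spacy-llm[langchain]'."
        )

_ensure_langchain()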

spacy_llm/models/langchain/model.py

Lines changed: 22 additions & 18 deletions

@@ -2,11 +2,11 @@
 
 from confection import SimpleFrozenDict
 
-from ...compat import ExtraError, ValidationError, has_langchain, langchain
+from ...compat import ExtraError, ValidationError, has_langchain, langchain_community
 from ...registry import registry
 
 try:
-    from langchain import llms  # noqa: F401
+    from langchain_community import llms  # noqa: F401
 except (ImportError, AttributeError):
     llms = None
 
@@ -18,16 +18,17 @@ def __init__(
         api: str,
         config: Dict[Any, Any],
         query: Callable[
-            ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
+            ["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
+            Iterable[Iterable[Any]],
         ],
         context_length: Optional[int],
     ):
         """Initializes model instance for integration APIs.
         name (str): Name of LangChain model to instantiate.
         api (str): Name of class/API.
         config (Dict[Any, Any]): Config passed on to LangChain model.
-        query (Callable[[langchain.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
-            LLM prompts when supplied with the model instance.
+        query (Callable[[langchain_community.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
+            executing LLM prompts when supplied with the model instance.
         context_length (Optional[int]): Context length for this model. Only necessary for sharding. If no context
             length provided, prompts can't be sharded.
         """
@@ -39,7 +40,7 @@ def __init__(
     @classmethod
     def _init_langchain_model(
         cls, name: str, api: str, config: Dict[Any, Any]
-    ) -> "langchain.llms.BaseLLM":
+    ) -> "langchain_community.llms.BaseLLM":
         """Initializes langchain model. langchain expects a range of different model ID argument names, depending on the
         model class. There doesn't seem to be a clean way to determine those from the outset, we'll fail our way through
         them.
@@ -73,12 +74,13 @@ def _init_langchain_model(
         raise err
 
     @staticmethod
-    def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]:
-        """Returns langchain.llms.type_to_cls_dict.
-        RETURNS (Dict[str, Type[langchain.llms.BaseLLM]]): langchain.llms.type_to_cls_dict.
+    def get_type_to_cls_dict() -> Dict[str, Type["langchain_community.llms.BaseLLM"]]:
+        """Returns langchain_community.llms.type_to_cls_dict.
+        RETURNS (Dict[str, Type[langchain_community.llms.BaseLLM]]): langchain_community.llms.type_to_cls_dict.
         """
         return {
-            llm_id: getattr(langchain.llms, llm_id) for llm_id in langchain.llms.__all__
+            llm_id: getattr(langchain_community.llms, llm_id)
+            for llm_id in langchain_community.llms.__all__
         }
 
     def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
@@ -90,15 +92,16 @@ def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
 
     @staticmethod
     def query_langchain(
-        model: "langchain.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
+        model: "langchain_community.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
     ) -> Iterable[Iterable[Any]]:
         """Query LangChain model naively.
-        model (langchain.llms.BaseLLM): LangChain model.
+        model (langchain_community.llms.BaseLLM): LangChain model.
         prompts (Iterable[Iterable[Any]]): Prompts to execute.
         RETURNS (Iterable[Iterable[Any]]): LLM responses.
        """
-        assert callable(model)
-        return [[model(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts]
+        return [
+            [model.invoke(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts
+        ]
 
     @staticmethod
     def _check_installation() -> None:
@@ -115,7 +118,7 @@ def langchain_model(
     name: str,
     query: Optional[
         Callable[
-            ["langchain.llms.BaseLLM", Iterable[Iterable[str]]],
+            ["langchain_community.llms.BaseLLM", Iterable[Iterable[str]]],
             Iterable[Iterable[str]],
         ]
     ] = None,
@@ -170,11 +173,12 @@ def register_models() -> None:
     @registry.llm_queries("spacy.CallLangChain.v1")
     def query_langchain() -> (
         Callable[
-            ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
+            ["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
+            Iterable[Iterable[Any]],
         ]
     ):
         """Returns query Callable for LangChain.
-        RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
-            simple prompts on the specified LangChain model.
+        RETURNS (Callable[["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
+            executing simple prompts on the specified LangChain model.
         """
         return LangChain.query_langchain
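
A quick way to exercise the new query path, for example in a REPL; this is illustrative only and not part of the commit, again using a fake LLM to avoid external calls (the printed output is an assumption based on FakeListLLM replaying its responses in order).

from langchain_community.llms import FakeListLLM

from spacy_llm.models.langchain.model import LangChain

llm = FakeListLLM(responses=["positive", "negative"])

# Prompts are grouped per document: one inner list of prompts per doc.
doc_prompts = [["Classify: great movie!"], ["Classify: boring plot."]]

# query_langchain() now calls model.invoke(prompt) under the hood.
responses = LangChain.query_langchain(llm, doc_prompts)
print(responses)  # [['positive'], ['negative']]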
