Skip to content

Commit 8b8d90b

Browse files
authored
1 parent 095f4a7 commit 8b8d90b

31 files changed

+181
-102
lines changed

libs/langchain/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ select = [
167167
"PGH", # pygrep-hooks
168168
"PIE", # flake8-pie
169169
"PERF", # flake8-perf
170+
"PT", # flake8-pytest-style
170171
"PTH", # flake8-use-pathlib
171172
"PYI", # flake8-pyi
172173
"Q", # flake8-quotes

libs/langchain/tests/integration_tests/embeddings/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
@pytest.mark.parametrize(
12-
"provider, model",
12+
("provider", "model"),
1313
[
1414
("openai", "text-embedding-3-large"),
1515
("google_vertexai", "text-embedding-gecko@003"),
Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
import unittest
2-
31
from langchain.agents.agent_types import AgentType
42
from langchain.agents.types import AGENT_TO_CLASS
53

64

7-
class TestTypes(unittest.TestCase):
8-
def test_confirm_full_coverage(self) -> None:
9-
self.assertEqual(list(AgentType), list(AGENT_TO_CLASS.keys()))
5+
def test_confirm_full_coverage() -> None:
6+
assert list(AgentType) == list(AGENT_TO_CLASS.keys())

libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
DEFAULT_PARSER = get_parser()
1717

1818

19-
@pytest.mark.parametrize("x", ("", "foo", 'foo("bar", "baz")'))
19+
@pytest.mark.parametrize("x", ["", "foo", 'foo("bar", "baz")'])
2020
def test_parse_invalid_grammar(x: str) -> None:
2121
with pytest.raises((ValueError, lark.exceptions.UnexpectedToken)):
2222
DEFAULT_PARSER.parse(x)
@@ -71,13 +71,13 @@ def test_parse_nested_operation() -> None:
7171

7272
def test_parse_disallowed_comparator() -> None:
7373
parser = get_parser(allowed_comparators=[Comparator.EQ])
74-
with pytest.raises(ValueError):
74+
with pytest.raises(ValueError, match="Received disallowed comparator gt."):
7575
parser.parse('gt("a", 2)')
7676

7777

7878
def test_parse_disallowed_operator() -> None:
7979
parser = get_parser(allowed_operators=[Operator.AND])
80-
with pytest.raises(ValueError):
80+
with pytest.raises(ValueError, match="Received disallowed operator not."):
8181
parser.parse('not(gt("a", 2))')
8282

8383

@@ -87,53 +87,53 @@ def _test_parse_value(x: Any) -> None:
8787
assert actual == x
8888

8989

90-
@pytest.mark.parametrize("x", (-1, 0, 1_000_000))
90+
@pytest.mark.parametrize("x", [-1, 0, 1_000_000])
9191
def test_parse_int_value(x: int) -> None:
9292
_test_parse_value(x)
9393

9494

95-
@pytest.mark.parametrize("x", (-1.001, 0.00000002, 1_234_567.6543210))
95+
@pytest.mark.parametrize("x", [-1.001, 0.00000002, 1_234_567.6543210])
9696
def test_parse_float_value(x: float) -> None:
9797
_test_parse_value(x)
9898

9999

100-
@pytest.mark.parametrize("x", ([], [1, "b", "true"]))
100+
@pytest.mark.parametrize("x", [[], [1, "b", "true"]])
101101
def test_parse_list_value(x: list) -> None:
102102
_test_parse_value(x)
103103

104104

105-
@pytest.mark.parametrize("x", ('""', '" "', '"foo"', "'foo'"))
105+
@pytest.mark.parametrize("x", ['""', '" "', '"foo"', "'foo'"])
106106
def test_parse_string_value(x: str) -> None:
107107
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
108108
actual = parsed.value
109109
assert actual == x[1:-1]
110110

111111

112-
@pytest.mark.parametrize("x", ("true", "True", "TRUE", "false", "False", "FALSE"))
112+
@pytest.mark.parametrize("x", ["true", "True", "TRUE", "false", "False", "FALSE"])
113113
def test_parse_bool_value(x: str) -> None:
114114
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
115115
actual = parsed.value
116116
expected = x.lower() == "true"
117117
assert actual == expected
118118

119119

120-
@pytest.mark.parametrize("op", ("and", "or"))
121-
@pytest.mark.parametrize("arg", ('eq("foo", 2)', 'and(eq("foo", 2), lte("bar", 1.1))'))
120+
@pytest.mark.parametrize("op", ["and", "or"])
121+
@pytest.mark.parametrize("arg", ['eq("foo", 2)', 'and(eq("foo", 2), lte("bar", 1.1))'])
122122
def test_parser_unpack_single_arg_operation(op: str, arg: str) -> None:
123123
expected = DEFAULT_PARSER.parse(arg)
124124
actual = DEFAULT_PARSER.parse(f"{op}({arg})")
125125
assert expected == actual
126126

127127

128-
@pytest.mark.parametrize("x", ('"2022-10-20"', "'2022-10-20'", "2022-10-20"))
128+
@pytest.mark.parametrize("x", ['"2022-10-20"', "'2022-10-20'", "2022-10-20"])
129129
def test_parse_date_value(x: str) -> None:
130130
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
131131
actual = parsed.value["date"]
132132
assert actual == x.strip("'\"")
133133

134134

135135
@pytest.mark.parametrize(
136-
"x, expected",
136+
("x", "expected"),
137137
[
138138
(
139139
'"2021-01-01T00:00:00"',

libs/langchain/tests/unit_tests/chains/question_answering/test_map_rerank_prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
SCORE_WITH_EXPLANATION = "foo bar answer.\nScore: 80 (fully answers the question, but could provide more detail on the specific error message)" # noqa: E501
99

1010

11-
@pytest.mark.parametrize("answer", (GOOD_SCORE, SCORE_WITH_EXPLANATION))
11+
@pytest.mark.parametrize("answer", [GOOD_SCORE, SCORE_WITH_EXPLANATION])
1212
def test_parse_scores(answer: str) -> None:
1313
result = output_parser.parse(answer)
1414

libs/langchain/tests/unit_tests/chains/test_base.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ def _call(
6565
def test_bad_inputs() -> None:
6666
"""Test errors are raised if input keys are not found."""
6767
chain = FakeChain()
68-
with pytest.raises(ValueError):
68+
with pytest.raises(ValueError, match="Missing some input keys: {'foo'}"):
6969
chain({"foobar": "baz"})
7070

7171

7272
def test_bad_outputs() -> None:
7373
"""Test errors are raised if outputs keys are not found."""
7474
chain = FakeChain(be_correct=False)
75-
with pytest.raises(ValueError):
75+
with pytest.raises(ValueError, match="Missing some output keys: {'bar'}"):
7676
chain({"foo": "baz"})
7777

7878

@@ -102,7 +102,7 @@ def test_single_input_correct() -> None:
102102
def test_single_input_error() -> None:
103103
"""Test passing single input errors as expected."""
104104
chain = FakeChain(the_input_keys=["foo", "bar"])
105-
with pytest.raises(ValueError):
105+
with pytest.raises(ValueError, match="Missing some input keys:"):
106106
chain("bar")
107107

108108

@@ -116,7 +116,9 @@ def test_run_single_arg() -> None:
116116
def test_run_multiple_args_error() -> None:
117117
"""Test run method with multiple args errors as expected."""
118118
chain = FakeChain()
119-
with pytest.raises(ValueError):
119+
with pytest.raises(
120+
ValueError, match="`run` supports only one positional argument."
121+
):
120122
chain.run("bar", "foo")
121123

122124

@@ -130,21 +132,28 @@ def test_run_kwargs() -> None:
130132
def test_run_kwargs_error() -> None:
131133
"""Test run method with kwargs errors as expected."""
132134
chain = FakeChain(the_input_keys=["foo", "bar"])
133-
with pytest.raises(ValueError):
135+
with pytest.raises(ValueError, match="Missing some input keys: {'bar'}"):
134136
chain.run(foo="bar", baz="foo")
135137

136138

137139
def test_run_args_and_kwargs_error() -> None:
138140
"""Test run method with args and kwargs."""
139141
chain = FakeChain(the_input_keys=["foo", "bar"])
140-
with pytest.raises(ValueError):
142+
with pytest.raises(
143+
ValueError,
144+
match="`run` supported with either positional arguments "
145+
"or keyword arguments but not both.",
146+
):
141147
chain.run("bar", foo="bar")
142148

143149

144150
def test_multiple_output_keys_error() -> None:
145151
"""Test run with multiple output keys errors as expected."""
146152
chain = FakeChain(the_output_keys=["foo", "bar"])
147-
with pytest.raises(ValueError):
153+
with pytest.raises(
154+
ValueError,
155+
match="`run` not supported when there is not exactly one output key.",
156+
):
148157
chain.run("bar")
149158

150159

@@ -175,7 +184,7 @@ def test_run_with_callback_and_input_error() -> None:
175184
callbacks=[handler],
176185
)
177186

178-
with pytest.raises(ValueError):
187+
with pytest.raises(ValueError, match="Missing some input keys: {'foo'}"):
179188
chain({"bar": "foo"})
180189

181190
assert handler.starts == 1
@@ -222,7 +231,7 @@ def test_run_with_callback_and_output_error() -> None:
222231
callbacks=[handler],
223232
)
224233

225-
with pytest.raises(ValueError):
234+
with pytest.raises(ValueError, match="Missing some output keys: {'foo'}"):
226235
chain("foo")
227236

228237
assert handler.starts == 1

libs/langchain/tests/unit_tests/chains/test_combine_documents.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test functionality related to combining documents."""
22

3+
import re
34
from typing import Any
45

56
import pytest
@@ -30,7 +31,9 @@ def test_multiple_input_keys() -> None:
3031
def test__split_list_long_single_doc() -> None:
3132
"""Test splitting of a long single doc."""
3233
docs = [Document(page_content="foo" * 100)]
33-
with pytest.raises(ValueError):
34+
with pytest.raises(
35+
ValueError, match="A single document was longer than the context length"
36+
):
3437
split_list_of_docs(docs, _fake_docs_len_func, 100)
3538

3639

@@ -140,7 +143,17 @@ async def test_format_doc_missing_metadata() -> None:
140143
input_variables=["page_content", "bar"],
141144
template="{page_content}, {bar}",
142145
)
143-
with pytest.raises(ValueError):
146+
with pytest.raises(
147+
ValueError,
148+
match=re.escape(
149+
"Document prompt requires documents to have metadata variables: ['bar']."
150+
),
151+
):
144152
format_document(doc, prompt)
145-
with pytest.raises(ValueError):
153+
with pytest.raises(
154+
ValueError,
155+
match=re.escape(
156+
"Document prompt requires documents to have metadata variables: ['bar']."
157+
),
158+
):
146159
await aformat_document(doc, prompt)

libs/langchain/tests/unit_tests/chains/test_conversation.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test conversation chain and memory."""
22

3+
import re
34
from typing import Any, Optional
45

56
import pytest
@@ -76,7 +77,9 @@ def test_conversation_chain_errors_bad_prompt() -> None:
7677
"""Test that conversation chain raise error with bad prompt."""
7778
llm = FakeLLM()
7879
prompt = PromptTemplate(input_variables=[], template="nothing here")
79-
with pytest.raises(ValueError):
80+
with pytest.raises(
81+
ValueError, match="Value error, Got unexpected prompt input variables."
82+
):
8083
ConversationChain(llm=llm, prompt=prompt)
8184

8285

@@ -85,7 +88,12 @@ def test_conversation_chain_errors_bad_variable() -> None:
8588
llm = FakeLLM()
8689
prompt = PromptTemplate(input_variables=["foo"], template="{foo}")
8790
memory = ConversationBufferMemory(memory_key="foo")
88-
with pytest.raises(ValueError):
91+
with pytest.raises(
92+
ValueError,
93+
match=re.escape(
94+
"Value error, The input key foo was also found in the memory keys (['foo'])"
95+
),
96+
):
8997
ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="foo")
9098

9199

@@ -106,18 +114,18 @@ def test_conversation_memory(memory: BaseMemory) -> None:
106114
memory.save_context(good_inputs, good_outputs)
107115
# This is a bad input because there are two variables that aren't the same as baz.
108116
bad_inputs = {"foo": "bar", "foo1": "bar"}
109-
with pytest.raises(ValueError):
117+
with pytest.raises(ValueError, match="One input key expected"):
110118
memory.save_context(bad_inputs, good_outputs)
111119
# This is a bad input because the only variable is the same as baz.
112120
bad_inputs = {"baz": "bar"}
113-
with pytest.raises(ValueError):
121+
with pytest.raises(ValueError, match=re.escape("One input key expected got []")):
114122
memory.save_context(bad_inputs, good_outputs)
115123
# This is a bad output because it is empty.
116-
with pytest.raises(ValueError):
124+
with pytest.raises(ValueError, match="Got multiple output keys"):
117125
memory.save_context(good_inputs, {})
118126
# This is a bad output because there are two keys.
119127
bad_outputs = {"foo": "bar", "foo1": "bar"}
120-
with pytest.raises(ValueError):
128+
with pytest.raises(ValueError, match="Got multiple output keys"):
121129
memory.save_context(good_inputs, bad_outputs)
122130

123131

libs/langchain/tests/unit_tests/chains/test_llm_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None:
3939
@pytest.mark.requires("numexpr")
4040
def test_error(fake_llm_math_chain: LLMMathChain) -> None:
4141
"""Test question that raises error."""
42-
with pytest.raises(ValueError):
42+
with pytest.raises(ValueError, match="unknown format from LLM: foo"):
4343
fake_llm_math_chain.run("foo")

libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
@pytest.mark.parametrize(
8-
"text,answer,sources",
8+
("text", "answer", "sources"),
99
[
1010
(
1111
"This Agreement is governed by English law.\nSOURCES: 28-pl",

0 commit comments

Comments
 (0)