Skip to content

Commit 6d08690

Browse files
committed
applied the fix + the windows permission issue, you can ignore if you don't want it but ran all of the tests and passed
1 parent acd1aa8 commit 6d08690

File tree

4 files changed

+140
-16
lines changed

4 files changed

+140
-16
lines changed

libs/core/langchain_core/messages/utils.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,11 +1440,10 @@ def _first_max_tokens(
14401440
# When all messages fit, only apply end_on filtering if needed
14411441
if end_on:
14421442
for _ in range(len(messages)):
1443-
if not _is_message_type(messages[-1], end_on):
1444-
messages.pop()
1445-
else:
1443+
if not messages or _is_message_type(messages[-1], end_on):
14461444
break
1447-
return messages
1445+
messages.pop()
1446+
return _remove_orphaned_tool_messages(messages)
14481447

14491448
# Use binary search to find the maximum number of messages within token limit
14501449
left, right = 0, len(messages)
@@ -1535,7 +1534,7 @@ def _first_max_tokens(
15351534
else:
15361535
break
15371536

1538-
return messages[:idx]
1537+
return _remove_orphaned_tool_messages(messages[:idx])
15391538

15401539

15411540
def _last_max_tokens(
@@ -1594,7 +1593,42 @@ def _last_max_tokens(
15941593
if system_message:
15951594
result = [system_message, *result]
15961595

1597-
return result
1596+
return _remove_orphaned_tool_messages(result)
1597+
1598+
1599+
def _remove_orphaned_tool_messages(
1600+
messages: Sequence[BaseMessage],
1601+
) -> list[BaseMessage]:
1602+
"""Drop tool messages whose corresponding tool calls are absent."""
1603+
if not messages:
1604+
return []
1605+
1606+
valid_tool_call_ids: set[str] = set()
1607+
for message in messages:
1608+
if isinstance(message, AIMessage):
1609+
if message.tool_calls:
1610+
for tool_call in message.tool_calls:
1611+
tool_call_id = tool_call.get("id")
1612+
if tool_call_id:
1613+
valid_tool_call_ids.add(tool_call_id)
1614+
if isinstance(message.content, list):
1615+
for block in message.content:
1616+
if (
1617+
isinstance(block, dict)
1618+
and block.get("type") == "tool_use"
1619+
and block.get("id")
1620+
):
1621+
valid_tool_call_ids.add(block["id"])
1622+
1623+
cleaned_messages: list[BaseMessage] = []
1624+
for message in messages:
1625+
if isinstance(message, ToolMessage) and (
1626+
not valid_tool_call_ids
1627+
or message.tool_call_id not in valid_tool_call_ids
1628+
):
1629+
continue
1630+
cleaned_messages.append(message)
1631+
return cleaned_messages
15981632

15991633

16001634
_MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {

libs/core/tests/unit_tests/messages/test_utils.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,84 @@ def test_trim_messages_last_30_include_system_allow_partial_end_on_human() -> No
393393
assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
394394

395395

396+
def test_trim_messages_last_removes_orphaned_tool_message() -> None:
397+
messages = [
398+
HumanMessage("What's the weather in Florida?"),
399+
AIMessage(
400+
[
401+
{"type": "text", "text": "Let's check the weather in Florida"},
402+
{
403+
"type": "tool_use",
404+
"id": "abc123",
405+
"name": "get_weather",
406+
"input": {"location": "Florida"},
407+
},
408+
],
409+
tool_calls=[
410+
{
411+
"name": "get_weather",
412+
"args": {"location": "Florida"},
413+
"id": "abc123",
414+
"type": "tool_call",
415+
}
416+
],
417+
),
418+
ToolMessage("It's sunny.", name="get_weather", tool_call_id="abc123"),
419+
HumanMessage("I see"),
420+
AIMessage("Do you want to know anything else?"),
421+
HumanMessage("No, thanks"),
422+
AIMessage("You're welcome! Have a great day!"),
423+
]
424+
425+
trimmed = trim_messages(
426+
messages,
427+
strategy="last",
428+
token_counter=len,
429+
max_tokens=5,
430+
)
431+
432+
expected = [
433+
HumanMessage("I see"),
434+
AIMessage("Do you want to know anything else?"),
435+
HumanMessage("No, thanks"),
436+
AIMessage("You're welcome! Have a great day!"),
437+
]
438+
439+
assert trimmed == expected
440+
assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
441+
442+
443+
def test_trim_messages_last_preserves_tool_message_when_call_present() -> None:
444+
messages = [
445+
HumanMessage("Start"),
446+
AIMessage(
447+
"Sure, let me check",
448+
tool_calls=[
449+
{
450+
"name": "search",
451+
"args": {"query": "status"},
452+
"id": "tool-1",
453+
"type": "tool_call",
454+
}
455+
],
456+
),
457+
ToolMessage("All systems operational", tool_call_id="tool-1"),
458+
HumanMessage("Thanks"),
459+
]
460+
461+
trimmed = trim_messages(
462+
messages,
463+
strategy="last",
464+
token_counter=len,
465+
max_tokens=3,
466+
)
467+
468+
expected = messages[1:]
469+
470+
assert trimmed == expected
471+
assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
472+
473+
396474
def test_trim_messages_last_40_include_system_allow_partial_start_on_human() -> None:
397475
expected = [
398476
SystemMessage("This is a 4 token text."),

libs/core/tests/unit_tests/prompts/test_prompt.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Test functionality related to prompts."""
22

33
import re
4+
from pathlib import Path
45
from tempfile import NamedTemporaryFile
5-
from typing import Any, Literal, Union
6+
from typing import Any, Literal, Union, cast
67
from unittest import mock
78

89
import pytest
@@ -18,6 +19,11 @@
1819
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
1920

2021

22+
def _normalize_blank_lines(text: str) -> str:
23+
"""Collapse whitespace-only lines to simplify cross-platform comparisons."""
24+
return "\n".join("" if not line.strip() else line for line in text.splitlines())
25+
26+
2127
def test_prompt_valid() -> None:
2228
"""Test prompts can be constructed."""
2329
template = "This is a {foo} test."
@@ -33,19 +39,22 @@ def test_from_file_encoding() -> None:
3339
input_variables = ["foo"]
3440

3541
# First write to a file using CP-1252 encoding.
36-
with NamedTemporaryFile(delete=True, mode="w", encoding="cp1252") as f:
42+
with NamedTemporaryFile(delete=False, mode="w", encoding="cp1252") as f:
3743
f.write(template)
3844
f.flush()
39-
file_name = f.name
45+
file_path = Path(f.name)
4046

47+
try:
4148
# Now read from the file using CP-1252 encoding and test
42-
prompt = PromptTemplate.from_file(file_name, encoding="cp1252")
49+
prompt = PromptTemplate.from_file(file_path, encoding="cp1252")
4350
assert prompt.template == template
4451
assert prompt.input_variables == input_variables
4552

4653
# Now read from the file using UTF-8 encoding and test
4754
with pytest.raises(UnicodeDecodeError):
48-
PromptTemplate.from_file(file_name, encoding="utf-8")
55+
PromptTemplate.from_file(file_path, encoding="utf-8")
56+
finally:
57+
file_path.unlink(missing_ok=True)
4958

5059

5160
def test_prompt_from_template() -> None:
@@ -216,10 +225,12 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
216225
{{bar}}
217226
{{/foo}}is a test."""
218227
prompt = PromptTemplate.from_template(template, template_format="mustache")
219-
assert prompt.format(foo=[{"bar": "yo"}, {"bar": "hello"}]) == (
228+
assert _normalize_blank_lines(
229+
prompt.format(foo=[{"bar": "yo"}, {"bar": "hello"}])
230+
) == (
220231
"""This
221232
yo
222-
233+
223234
hello
224235
is a test.""" # noqa: W293
225236
)
@@ -347,7 +358,7 @@ def test_prompt_invalid_template_format() -> None:
347358
PromptTemplate(
348359
input_variables=input_variables,
349360
template=template,
350-
template_format="bar",
361+
template_format=cast(PromptTemplateFormat, "bar"),
351362
)
352363

353364

@@ -681,7 +692,7 @@ def test_prompt_with_template_variable_name_jinja2() -> None:
681692

682693
def test_prompt_template_add_with_with_another_format() -> None:
683694
with pytest.raises(ValueError, match=r"Cannot add templates"):
684-
(
695+
_ = (
685696
PromptTemplate.from_template("This is a {template}")
686697
+ PromptTemplate.from_template("So {{this}} is", template_format="mustache")
687698
)

libs/core/tests/unit_tests/test_imports.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import concurrent.futures
22
import importlib
33
import subprocess
4+
import sys
45
from pathlib import Path
56

67

@@ -26,7 +27,7 @@ def try_to_import(module_name: str) -> tuple[int, str]:
2627
getattr(module, cls_)
2728

2829
result = subprocess.run(
29-
["python", "-c", f"import langchain_core.{module_name}"], check=True
30+
[sys.executable, "-c", f"import langchain_core.{module_name}"], check=True
3031
)
3132
return result.returncode, module_name
3233

0 commit comments

Comments
 (0)