142 changes: 136 additions & 6 deletions libs/community/langchain_community/chat_models/mlx.py
@@ -1,5 +1,9 @@
"""MLX Chat Wrapper."""

import json
import logging
import re
import uuid
from typing import (
Any,
Callable,
@@ -9,8 +13,10 @@
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)

from langchain_core.callbacks.manager import (
@@ -24,7 +30,13 @@
AIMessageChunk,
BaseMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import (
ChatGeneration,
@@ -38,9 +50,58 @@

from langchain_community.llms.mlx_pipeline import MLXPipeline

logger = logging.getLogger(__name__)

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""


def _parse_react_tool_calls(
    text: str,
) -> Tuple[Optional[List[ToolCall]], List[InvalidToolCall]]:
"""Extract ReAct-style tool calls from plain text output.

Args:
text: Raw model generation text.

Returns:
A tuple containing a list of parsed ``ToolCall`` objects if any were
detected, otherwise ``None``, and a list of ``InvalidToolCall`` objects
for unparseable patterns.
"""

    tool_calls: List[ToolCall] = []
    invalid_tool_calls: List[InvalidToolCall] = []

bracket_pattern = r"Action:\s*(?P<name>[\w.-]+)\[(?P<input>[^\]]+)\]"
separate_pattern = r"Action:\s*(?P<name>[^\n]+)\nAction Input:\s*(?P<input>[^\n]+)"

matches = list(re.finditer(bracket_pattern, text))
if not matches:
matches = list(re.finditer(separate_pattern, text))

for match in matches:
name = match.group("name").strip()
arg_text = match.group("input").strip()
try:
args = json.loads(arg_text)
if not isinstance(args, dict):
args = {"input": args}
except Exception:
args = {"input": arg_text}
tool_calls.append(ToolCall(id=str(uuid.uuid4()), name=name, args=args))

    if not tool_calls and "Action:" in text:
        invalid_tool_calls.append(
            InvalidToolCall(
                name=None,
                args=text,
                id=None,
                error="Could not parse ReAct tool call",
            )
        )
        return None, invalid_tool_calls

return tool_calls or None, invalid_tool_calls
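
# Hedged usage sketch (not part of the diff): what the fallback parser accepts,
# assuming the patched module is importable. Tool names and inputs below are
# illustrative, not fixtures from this PR.
from langchain_community.chat_models.mlx import _parse_react_tool_calls

# Bracketed form -- Action: tool_name[JSON args]
calls, invalid = _parse_react_tool_calls('Action: multiply[{"a": 2, "b": 3}]')
assert calls is not None and calls[0]["name"] == "multiply"
assert calls[0]["args"] == {"a": 2, "b": 3}

# Two-line form -- Action / Action Input; non-JSON input is wrapped under "input"
calls, invalid = _parse_react_tool_calls("Action: search\nAction Input: mlx docs")
assert calls is not None and calls[0]["args"] == {"input": "mlx docs"}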


class ChatMLX(BaseChatModel):
"""MLX chat models.

@@ -69,14 +130,34 @@ def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.tokenizer = self.llm.tokenizer

def _parse_tool_args(self, arg_text: str) -> Dict[str, Any]:
"""Parse the arguments for a tool call.

Args:
arg_text: JSON string representation of the tool arguments.

        Returns:
            Parsed arguments dictionary. If parsing fails, or the decoded JSON
            is not an object, the value is wrapped under the ``input`` key.
        """
        try:
            args = json.loads(arg_text)
            if not isinstance(args, dict):
                args = {"input": args}
        except json.JSONDecodeError:
            args = {"input": arg_text}
        except Exception as e:  # pragma: no cover - defensive
            logger.warning("Unexpected error during tool argument parsing: %s", e)
            args = {"input": arg_text}
        return args
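
# Hedged standalone restatement (illustrative only, mirrors the method above)
# of how tool arguments are normalized; handy for seeing the fallback without
# loading a model.
import json

def parse_tool_args(arg_text: str) -> dict:
    # A JSON object passes through; invalid JSON or a bare scalar is wrapped.
    try:
        args = json.loads(arg_text)
        if not isinstance(args, dict):
            args = {"input": args}
    except json.JSONDecodeError:
        args = {"input": arg_text}
    return args

assert parse_tool_args('{"a": 2, "b": 3}') == {"a": 2, "b": 3}
assert parse_tool_args("not json") == {"input": "not json"}
assert parse_tool_args("3") == {"input": 3}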

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
-        llm_input = self._to_chat_prompt(messages)
+        tools = kwargs.pop("tools", None)
+        llm_input = self._to_chat_prompt(messages, tools=tools)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
@@ -89,7 +170,8 @@ def _agenerate(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
-        llm_input = self._to_chat_prompt(messages)
+        tools = kwargs.pop("tools", None)
+        llm_input = self._to_chat_prompt(messages, tools=tools)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
@@ -100,8 +182,17 @@ def _to_chat_prompt(
messages: List[BaseMessage],
tokenize: bool = False,
return_tensors: Optional[str] = None,
        tools: Optional[Sequence[dict]] = None,
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
"""Convert messages to the prompt format expected by the wrapped LLM.

Args:
messages: Chat messages to include in the prompt.
tokenize: Whether to return token IDs instead of text.
return_tensors: Framework for returned tensors when ``tokenize`` is
True.
tools: Optional tool definitions to include in the prompt.
"""
if not messages:
raise ValueError("At least one HumanMessage must be provided!")

@@ -114,6 +205,7 @@ def _to_chat_prompt(
tokenize=tokenize,
add_generation_prompt=True,
return_tensors=return_tensors,
tools=tools,
)
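
# Hedged end-to-end sketch of prompt construction (not part of the diff). The
# model id is illustrative, and it assumes a tokenizer chat template that
# understands the `tools` argument.
from langchain_core.messages import HumanMessage
from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline

llm = MLXPipeline.from_model_id("mlx-community/phi-3-mini-128k-instruct")
chat = ChatMLX(llm=llm)
tools = [
    {
        "type": "function",
        "function": {
            "name": "multiply",
            "description": "Multiply two integers.",
            "parameters": {
                "type": "object",
                "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
                "required": ["a", "b"],
            },
        },
    }
]
prompt = chat._to_chat_prompt([HumanMessage(content="What is 2 * 3?")], tools=tools)
print(prompt)  # chat-template text with the tool definition embedded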

def _to_chatml_format(self, message: BaseMessage) -> dict:
@@ -135,8 +227,41 @@ def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []

for g in llm_result.generations[0]:
            tool_calls: List[ToolCall] = []
            invalid_tool_calls: List[InvalidToolCall] = []
            additional_kwargs: Dict[str, Any] = {}

if isinstance(g.generation_info, dict):
raw_tool_calls = g.generation_info.get("tool_calls")
else:
raw_tool_calls = None

if raw_tool_calls:
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tc = parse_tool_call(raw_tool_call, return_id=True)
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
else:
if tc:
tool_calls.append(tc)
else:
react_tool_calls, invalid_reacts = _parse_react_tool_calls(g.text)
if react_tool_calls is not None:
tool_calls.extend(react_tool_calls)
invalid_tool_calls.extend(invalid_reacts)

            chat_generation = ChatGeneration(
-                message=AIMessage(content=g.text), generation_info=g.generation_info
+                message=AIMessage(
+                    content=g.text,
+                    additional_kwargs=additional_kwargs,
+                    tool_calls=tool_calls,
+                    invalid_tool_calls=invalid_tool_calls,
+                ),
+                generation_info=g.generation_info,
            )
chat_generations.append(chat_generation)
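
# Hedged sketch of the ReAct fallback path end to end (not part of the diff).
# No model is needed: `_to_chat_result` is a static method on ChatMLX.
from langchain_core.outputs import Generation, LLMResult
from langchain_community.chat_models.mlx import ChatMLX

result = ChatMLX._to_chat_result(
    LLMResult(generations=[[Generation(text='Action: multiply[{"a": 2, "b": 3}]')]])
)
message = result.generations[0].message
assert message.tool_calls and message.tool_calls[0]["name"] == "multiply"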

@@ -244,7 +369,9 @@ def bind_tools(
:class:`~langchain.runnable.Runnable` constructor.
"""

-        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        formatted_tools: List[Dict[str, Any]] = [
+            convert_to_openai_tool(tool) for tool in tools
+        ]
if tool_choice is not None and tool_choice:
if len(formatted_tools) != 1:
raise ValueError(
Expand Down Expand Up @@ -274,4 +401,7 @@ def bind_tools(
f"Received: {tool_choice}"
)
kwargs["tool_choice"] = tool_choice
-        return super().bind(tools=formatted_tools, **kwargs)
+        return super().bind(
+            tools=cast(Sequence[Dict[str, Any]], formatted_tools),
+            **kwargs,
+        )
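
# Typical hedged usage, mirroring the integration test at the bottom of this
# diff: bind a single tool and force it with tool_choice=True. Whether a
# parseable call comes back still depends on the model.
from langchain_core.tools import tool

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

bound = chat.bind_tools([multiply], tool_choice=True)  # `chat` as in the sketch above
ai_msg = bound.invoke("Use the multiply tool to compute 2 * 3.")
print(ai_msg.tool_calls)  # e.g. [{"name": "multiply", "args": {"a": 2, "b": 3}, ...}]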
8 changes: 6 additions & 2 deletions libs/community/langchain_community/document_loaders/rss.py
@@ -124,11 +124,15 @@ def lazy_load(self) -> Iterator[Document]:
)
article = loader.load()[0]
article.metadata["feed"] = url
-                # If the publish date is not set by newspaper, try to extract it from the feed entry
+                # If the publish date is not set by newspaper, try to extract it
+                # from the feed entry
if article.metadata.get("publish_date") is None:
from datetime import datetime

publish_date = entry.get("published_parsed", None)
article.metadata["publish_date"] = datetime(*publish_date[:6]) if publish_date else None
article.metadata["publish_date"] = (
datetime(*publish_date[:6]) if publish_date else None
)
yield article
except Exception as e:
if self.continue_on_failure:
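
# Why the publish-date fallback above works (hedged mini-example): feedparser's
# `published_parsed` is a time.struct_time whose first six fields line up with
# datetime's positional arguments.
import time
from datetime import datetime

published = time.struct_time((2024, 5, 1, 12, 30, 0, 2, 122, 0))
assert datetime(*published[:6]) == datetime(2024, 5, 1, 12, 30, 0)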
9 changes: 7 additions & 2 deletions libs/community/langchain_community/tools/sql_database/tool.py
@@ -1,7 +1,7 @@
# flake8: noqa
"""Tools for interacting with a SQL database."""

-from typing import Any, Dict, Optional, Sequence, Type, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

from sqlalchemy.engine import Result

@@ -54,7 +54,12 @@ def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
-    ) -> Union[str, Sequence[Dict[str, Any]], Result]:
+    ) -> Union[
+        str,
+        Sequence[Dict[str, Any]],
+        Sequence[Tuple[Any, ...]],
+        Result[Any],
+    ]:
"""Execute the query, return the results or an error message."""
return self.db.run_no_throw(query)
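
# Hedged sketch of why the annotation widened: run_no_throw returns a plain
# string by default (including the error text on failure) and other shapes
# depending on the fetch mode.
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///:memory:")
print(db.run_no_throw("SELECT 1"))               # "[(1,)]"
print(db.run_no_throw("SELECT * FROM missing"))  # "Error: (sqlite3.OperationalError) ..."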

10 changes: 5 additions & 5 deletions libs/community/langchain_community/vectorstores/duckdb.py
@@ -5,7 +5,7 @@
import logging
import uuid
import warnings
-from typing import Any, Iterable, List, Optional, Type
+from typing import Any, Iterable, List, Optional, Type, cast

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
@@ -224,12 +224,12 @@ def similarity_search_pd(
list_cosine_similarity = self.duckdb.FunctionExpression(
"list_cosine_similarity",
self.duckdb.ColumnExpression(self._vector_key),
-            self.duckdb.ConstantExpression(embedding),
+            self.duckdb.ConstantExpression(cast(str, embedding)),
)
docs = (
self._table.select(
*[
-                    self.duckdb.StarExpression(exclude=[]),
+                    self.duckdb.StarExpression(),
list_cosine_similarity.alias(SIMILARITY_ALIAS),
]
)
@@ -269,12 +269,12 @@ def similarity_search(
list_cosine_similarity = self.duckdb.FunctionExpression(
"list_cosine_similarity",
self.duckdb.ColumnExpression(self._vector_key),
-            self.duckdb.ConstantExpression(embedding),
+            self.duckdb.ConstantExpression(cast(str, embedding)),
)
docs = (
self._table.select(
*[
-                    self.duckdb.StarExpression(exclude=[]),
+                    self.duckdb.StarExpression(),
list_cosine_similarity.alias(SIMILARITY_ALIAS),
]
)
@@ -0,0 +1,55 @@
"""Tests ChatMLX tool calling."""

from typing import Dict

import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline

# Use a Phi-3 model for more reliable tool-calling behavior
MODEL_ID = "mlx-community/phi-3-mini-128k-instruct"


@tool
def multiply(a: int, b: int) -> int:
"""Multiply two integers."""
return a * b


@pytest.fixture(scope="module")
def chat() -> ChatMLX:
"""Return ChatMLX bound with the multiply tool or skip if unavailable."""
try:
llm = MLXPipeline.from_model_id(
model_id=MODEL_ID, pipeline_kwargs={"max_new_tokens": 150}
)
    except Exception:
        pytest.skip("Required MLX model isn't available.")
chat_model = ChatMLX(llm=llm)
return chat_model.bind_tools(tools=[multiply], tool_choice=True) # type: ignore[return-value]


def _call_tool(tool_call: Dict) -> ToolMessage:
result = multiply.invoke(tool_call["args"])
return ToolMessage(content=str(result), tool_call_id=tool_call.get("id", ""))


def test_mlx_tool_calls_soft(chat: ChatMLX) -> None:
messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
ai_msg = chat.invoke(messages)
tool_msg = _call_tool(ai_msg.tool_calls[0])
final = chat.invoke(messages + [ai_msg, tool_msg])
assert "6" in final.content


def test_mlx_tool_calls_hard(chat: ChatMLX) -> None:
messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
ai_msg = chat.invoke(messages)
assert isinstance(ai_msg, AIMessage)
assert ai_msg.tool_calls
tool_call = ai_msg.tool_calls[0]
assert tool_call["name"] == "multiply"
assert tool_call["args"] == {"a": 2, "b": 3}