Skip to content

Commit b1cc972

Browse files
authored
core[patch]: Improve RunnableWithMessageHistory init arg types (#31639)
`Runnable`'s `Input` is contravariant so we need to enumerate all possible inputs and it's not possible to put them in a `Union`. Also, it's better to only require a runnable that accepts`list[BaseMessage]` instead of a broader `Sequence[BaseMessage]` as internally the runnable is only called with a list.
1 parent dcf5c7b commit b1cc972

File tree

4 files changed

+15
-12
lines changed

4 files changed

+15
-12
lines changed

libs/core/langchain_core/runnables/history.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,11 @@ def __init__(
241241
self,
242242
runnable: Union[
243243
Runnable[
244-
Union[MessagesOrDictWithMessages],
244+
list[BaseMessage],
245+
Union[str, BaseMessage, MessagesOrDictWithMessages],
246+
],
247+
Runnable[
248+
dict[str, Any],
245249
Union[str, BaseMessage, MessagesOrDictWithMessages],
246250
],
247251
LanguageModelLike,
@@ -258,7 +262,7 @@ def __init__(
258262
259263
Args:
260264
runnable: The base Runnable to be wrapped. Must take as input one of:
261-
1. A sequence of BaseMessages
265+
1. A list of BaseMessages
262266
2. A dict with one key for all messages
263267
3. A dict with one key for the current input string/message(s) and
264268
a separate key for historical messages. If the input key points

libs/core/tests/unit_tests/runnables/test_history.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def get_session_history(
553553
return store[(user_id, conversation_id)]
554554

555555
with_message_history = RunnableWithMessageHistory(
556-
runnable, # type: ignore[arg-type]
556+
runnable,
557557
get_session_history=get_session_history,
558558
input_messages_key="messages",
559559
history_messages_key="history",
@@ -666,7 +666,7 @@ def get_session_history(
666666
return store[(user_id, conversation_id)]
667667

668668
with_message_history = RunnableWithMessageHistory(
669-
runnable, # type: ignore[arg-type]
669+
runnable,
670670
get_session_history=get_session_history,
671671
input_messages_key="messages",
672672
history_messages_key="history",
@@ -769,13 +769,13 @@ def _fake_llm(messages: list[BaseMessage]) -> list[BaseMessage]:
769769

770770
runnable = RunnableLambda(_fake_llm)
771771
history = InMemoryChatMessageHistory()
772-
with_message_history = RunnableWithMessageHistory(runnable, lambda: history) # type: ignore[arg-type]
772+
with_message_history = RunnableWithMessageHistory(runnable, lambda: history)
773773
_ = with_message_history.invoke("hello")
774774
_ = with_message_history.invoke("hello again")
775775
assert len(history.messages) == 4
776776

777777

778-
class _RunnableLambdaWithRaiseError(RunnableLambda):
778+
class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
779779
from langchain_core.tracers.root_listeners import AsyncListener
780780

781781
def with_listeners(
@@ -861,7 +861,7 @@ def test_get_output_messages_with_value_error() -> None:
861861
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
862862
store: dict = {}
863863
get_session_history = _get_get_session_history(store=store)
864-
with_history = RunnableWithMessageHistory(runnable, get_session_history)
864+
with_history = RunnableWithMessageHistory(runnable, get_session_history) # type: ignore[arg-type]
865865
config: RunnableConfig = {
866866
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
867867
}
@@ -876,8 +876,8 @@ def test_get_output_messages_with_value_error() -> None:
876876
with_history.bound.invoke([HumanMessage(content="hello")], config)
877877

878878
illegal_int_message = 123
879-
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
880-
with_history = RunnableWithMessageHistory(runnable, get_session_history)
879+
runnable2 = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
880+
with_history = RunnableWithMessageHistory(runnable2, get_session_history) # type: ignore[arg-type]
881881

882882
with pytest.raises(
883883
ValueError,

libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from langchain_core.retrievers import BaseRetriever
2727
from langchain_core.runnables import (
2828
ConfigurableField,
29-
Runnable,
3029
RunnableConfig,
3130
RunnableLambda,
3231
)
@@ -1935,7 +1934,7 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
19351934
)
19361935
model = GenericFakeChatModel(messages=infinite_cycle)
19371936

1938-
chain: Runnable = prompt | model
1937+
chain = prompt | model
19391938
with_message_history = RunnableWithMessageHistory(
19401939
chain,
19411940
get_session_history=get_by_session_id,

libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1890,7 +1890,7 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
18901890
)
18911891
model = GenericFakeChatModel(messages=infinite_cycle)
18921892

1893-
chain: Runnable = prompt | model
1893+
chain = prompt | model
18941894
with_message_history = RunnableWithMessageHistory(
18951895
chain,
18961896
get_session_history=get_by_session_id,

0 commit comments

Comments
 (0)