Skip to content

Commit d7da8f4

Browse files
committed
Fix graph test
1 parent 93977f9 commit d7da8f4

File tree

1 file changed

+32
-26
lines changed

1 file changed

+32
-26
lines changed

posthog/test/ai/langchain/test_callbacks.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -433,30 +433,30 @@ def test_graph_state(mock_client):
433433
graph = StateGraph(FakeGraphState)
434434
graph.add_node(
435435
"fake_plain",
436-
lambda state: (
437-
{
438-
"messages": [
439-
*state["messages"],
440-
AIMessage(content="Let's explore bar."),
441-
],
442-
"xyz": "abc",
443-
}
444-
),
436+
lambda state: {
437+
"messages": [
438+
*state["messages"],
439+
AIMessage(content="Let's explore bar."),
440+
],
441+
"xyz": "abc",
442+
},
443+
)
444+
intermediate_chain = ChatPromptTemplate.from_messages(
445+
[("user", "Question: What's a bar?")]
446+
) | FakeMessagesListChatModel(
447+
responses=[
448+
AIMessage(content="It's a type of greeble."),
449+
]
445450
)
446451
graph.add_node(
447452
"fake_llm",
448-
lambda state: (
449-
ChatPromptTemplate.from_messages([("user", "Foo")])
450-
| FakeMessagesListChatModel(
451-
responses=[
452-
*state["messages"],
453-
AIMessage(content="It's a type of greeble."),
454-
]
455-
)
456-
).invoke(
457-
state,
458-
config=config,
459-
),
453+
lambda state: {
454+
"messages": [
455+
*state["messages"],
456+
intermediate_chain.invoke(state),
457+
],
458+
"xyz": state["xyz"],
459+
},
460460
)
461461
graph.add_edge(START, "fake_plain")
462462
graph.add_edge("fake_plain", "fake_llm")
@@ -467,28 +467,34 @@ def test_graph_state(mock_client):
467467
config=config,
468468
)
469469

470-
assert len(result["messages"]) == 2
470+
assert len(result["messages"]) == 3
471471
assert isinstance(result["messages"][0], HumanMessage)
472472
assert result["messages"][0].content == "What's a bar?"
473473
assert isinstance(result["messages"][1], AIMessage)
474474
assert result["messages"][1].content == "Let's explore bar."
475+
assert isinstance(result["messages"][2], AIMessage)
476+
assert result["messages"][2].content == "It's a type of greeble."
475477

476-
assert mock_client.capture.call_count == 3
478+
assert mock_client.capture.call_count == 2
477479
generation_args = mock_client.capture.call_args_list[0][1]
478-
trace_args = mock_client.capture.call_args_list[2][1]
480+
trace_args = mock_client.capture.call_args_list[1][1]
479481
assert generation_args["event"] == "$ai_generation"
480482
assert trace_args["event"] == "$ai_trace"
481483
assert trace_args["properties"]["$ai_trace_name"] == "LangGraph"
484+
482485
assert len(trace_args["properties"]["$ai_input_state"]["messages"]) == 1
483486
assert isinstance(trace_args["properties"]["$ai_input_state"]["messages"][0], HumanMessage)
484487
assert trace_args["properties"]["$ai_input_state"]["messages"][0].content == "What's a bar?"
485488
assert trace_args["properties"]["$ai_input_state"]["messages"][0].type == "human"
486489
assert trace_args["properties"]["$ai_input_state"]["xyz"] is None
487-
assert len(trace_args["properties"]["$ai_output_state"]["messages"]) == 2
488-
assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][0], HumanMessage) # FIXME
490+
assert len(trace_args["properties"]["$ai_output_state"]["messages"]) == 3
491+
492+
assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][0], HumanMessage)
489493
assert trace_args["properties"]["$ai_output_state"]["messages"][0].content == "What's a bar?"
490494
assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][1], AIMessage)
491495
assert trace_args["properties"]["$ai_output_state"]["messages"][1].content == "Let's explore bar."
496+
assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][2], AIMessage)
497+
assert trace_args["properties"]["$ai_output_state"]["messages"][2].content == "It's a type of greeble."
492498
assert trace_args["properties"]["$ai_output_state"]["xyz"] == "abc"
493499

494500

0 commit comments

Comments
 (0)