@@ -433,30 +433,30 @@ def test_graph_state(mock_client):
433433 graph = StateGraph (FakeGraphState )
434434 graph .add_node (
435435 "fake_plain" ,
436- lambda state : (
437- {
438- "messages" : [
439- * state ["messages" ],
440- AIMessage (content = "Let's explore bar." ),
441- ],
442- "xyz" : "abc" ,
443- }
444- ),
436+ lambda state : {
437+ "messages" : [
438+ * state ["messages" ],
439+ AIMessage (content = "Let's explore bar." ),
440+ ],
441+ "xyz" : "abc" ,
442+ },
443+ )
444+ intermediate_chain = ChatPromptTemplate .from_messages (
445+ [("user" , "Question: What's a bar?" )]
446+ ) | FakeMessagesListChatModel (
447+ responses = [
448+ AIMessage (content = "It's a type of greeble." ),
449+ ]
445450 )
446451 graph .add_node (
447452 "fake_llm" ,
448- lambda state : (
449- ChatPromptTemplate .from_messages ([("user" , "Foo" )])
450- | FakeMessagesListChatModel (
451- responses = [
452- * state ["messages" ],
453- AIMessage (content = "It's a type of greeble." ),
454- ]
455- )
456- ).invoke (
457- state ,
458- config = config ,
459- ),
453+ lambda state : {
454+ "messages" : [
455+ * state ["messages" ],
456+ intermediate_chain .invoke (state ),
457+ ],
458+ "xyz" : state ["xyz" ],
459+ },
460460 )
461461 graph .add_edge (START , "fake_plain" )
462462 graph .add_edge ("fake_plain" , "fake_llm" )
@@ -467,28 +467,34 @@ def test_graph_state(mock_client):
467467 config = config ,
468468 )
469469
470- assert len (result ["messages" ]) == 2
470+ assert len (result ["messages" ]) == 3
471471 assert isinstance (result ["messages" ][0 ], HumanMessage )
472472 assert result ["messages" ][0 ].content == "What's a bar?"
473473 assert isinstance (result ["messages" ][1 ], AIMessage )
474474 assert result ["messages" ][1 ].content == "Let's explore bar."
475+ assert isinstance (result ["messages" ][2 ], AIMessage )
476+ assert result ["messages" ][2 ].content == "It's a type of greeble."
475477
476- assert mock_client .capture .call_count == 3
478+ assert mock_client .capture .call_count == 2
477479 generation_args = mock_client .capture .call_args_list [0 ][1 ]
478- trace_args = mock_client .capture .call_args_list [2 ][1 ]
480+ trace_args = mock_client .capture .call_args_list [1 ][1 ]
479481 assert generation_args ["event" ] == "$ai_generation"
480482 assert trace_args ["event" ] == "$ai_trace"
481483 assert trace_args ["properties" ]["$ai_trace_name" ] == "LangGraph"
484+
482485 assert len (trace_args ["properties" ]["$ai_input_state" ]["messages" ]) == 1
483486 assert isinstance (trace_args ["properties" ]["$ai_input_state" ]["messages" ][0 ], HumanMessage )
484487 assert trace_args ["properties" ]["$ai_input_state" ]["messages" ][0 ].content == "What's a bar?"
485488 assert trace_args ["properties" ]["$ai_input_state" ]["messages" ][0 ].type == "human"
486489 assert trace_args ["properties" ]["$ai_input_state" ]["xyz" ] is None
487- assert len (trace_args ["properties" ]["$ai_output_state" ]["messages" ]) == 2
488- assert isinstance (trace_args ["properties" ]["$ai_output_state" ]["messages" ][0 ], HumanMessage ) # FIXME
490+ assert len (trace_args ["properties" ]["$ai_output_state" ]["messages" ]) == 3
491+
492+ assert isinstance (trace_args ["properties" ]["$ai_output_state" ]["messages" ][0 ], HumanMessage )
489493 assert trace_args ["properties" ]["$ai_output_state" ]["messages" ][0 ].content == "What's a bar?"
490494 assert isinstance (trace_args ["properties" ]["$ai_output_state" ]["messages" ][1 ], AIMessage )
491495 assert trace_args ["properties" ]["$ai_output_state" ]["messages" ][1 ].content == "Let's explore bar."
496+ assert isinstance (trace_args ["properties" ]["$ai_output_state" ]["messages" ][2 ], AIMessage )
497+ assert trace_args ["properties" ]["$ai_output_state" ]["messages" ][2 ].content == "It's a type of greeble."
492498 assert trace_args ["properties" ]["$ai_output_state" ]["xyz" ] == "abc"
493499
494500
0 commit comments