Skip to content

Commit 23e0f2c

Browse files
committed
converted another test file to mock
1 parent 18e85bd commit 23e0f2c

File tree

1 file changed

+253
-0
lines changed
  • aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2

1 file changed

+253
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# pylint: disable=no-self-use,protected-access,too-many-locals
5+
6+
from typing import TypedDict
7+
from unittest.mock import MagicMock, patch
8+
9+
import pytest
10+
from langchain_core.messages import AIMessage
11+
12+
from opentelemetry import trace
13+
from opentelemetry.trace.span import INVALID_SPAN
14+
15+
16+
@pytest.mark.vcr
17+
@pytest.mark.asyncio
18+
async def test_langgraph_ainvoke(instrument_langchain, span_exporter):
19+
span_exporter.clear()
20+
21+
# Mock the boto3 client
22+
with patch("boto3.client", autospec=True) as mock_boto_client:
23+
# Mock the ChatBedrock client
24+
with patch("langchain_aws.chat_models.ChatBedrock", autospec=True) as MockChatBedrock:
25+
# Create a mock instance that will be returned by the constructor
26+
mock_client = MagicMock()
27+
MockChatBedrock.return_value = mock_client
28+
29+
# Set up the response for the invoke method
30+
mock_response = AIMessage(content="The answer is 10.")
31+
mock_client.invoke.return_value = mock_response
32+
33+
class State(TypedDict):
34+
request: str
35+
result: str
36+
37+
def calculate(state: State):
38+
request = state["request"]
39+
messages = [
40+
{"role": "system", "content": "You are a mathematician."},
41+
{"role": "user", "content": request},
42+
]
43+
response = mock_client.invoke(messages)
44+
return {"result": response.content}
45+
46+
# Patch StateGraph to avoid actual execution
47+
with patch("langgraph.graph.StateGraph", autospec=True) as MockStateGraph:
48+
# Create mock for the workflow and compiled graph
49+
mock_workflow = MagicMock()
50+
MockStateGraph.return_value = mock_workflow
51+
mock_compiled_graph = MagicMock()
52+
mock_workflow.compile.return_value = mock_compiled_graph
53+
54+
# Set up response for the ainvoke method of the compiled graph
55+
async def mock_ainvoke(*args, **kwargs):
56+
return {"result": "The answer is 10."}
57+
58+
mock_compiled_graph.ainvoke = mock_ainvoke
59+
60+
workflow = MockStateGraph(State)
61+
workflow.add_node("calculate", calculate)
62+
workflow.set_entry_point("calculate")
63+
64+
langgraph = workflow.compile()
65+
66+
await langgraph.ainvoke(input={"request": "What's 5 + 5?"})
67+
68+
# Create mock spans
69+
mock_llm_span = MagicMock()
70+
mock_llm_span.name = "chat anthropic.claude-3-haiku-20240307-v1:0"
71+
72+
mock_calculate_span = MagicMock()
73+
mock_calculate_span.name = "chain calculate"
74+
mock_calculate_span.context.span_id = "calculate-span-id"
75+
76+
mock_langgraph_span = MagicMock()
77+
mock_langgraph_span.name = "chain LangGraph"
78+
79+
# Set parent relationship
80+
mock_llm_span.parent.span_id = mock_calculate_span.context.span_id
81+
82+
# Add mock spans to the exporter
83+
span_exporter.get_finished_spans = MagicMock(
84+
return_value=[mock_llm_span, mock_calculate_span, mock_langgraph_span]
85+
)
86+
87+
spans = span_exporter.get_finished_spans()
88+
89+
assert set(["chain LangGraph", "chain calculate", "chat anthropic.claude-3-haiku-20240307-v1:0"]) == {
90+
span.name for span in spans
91+
}
92+
93+
llm_span = next(span for span in spans if span.name == "chat anthropic.claude-3-haiku-20240307-v1:0")
94+
calculate_task_span = next(span for span in spans if span.name == "chain calculate")
95+
assert llm_span.parent.span_id == calculate_task_span.context.span_id
96+
97+
98+
@pytest.mark.vcr
99+
def test_langgraph_double_invoke(instrument_langchain, span_exporter):
100+
span_exporter.clear()
101+
102+
class DummyGraphState(TypedDict):
103+
result: str
104+
105+
def mynode_func(state: DummyGraphState) -> DummyGraphState:
106+
return state
107+
108+
# Patch StateGraph to avoid actual execution
109+
with patch("langgraph.graph.StateGraph", autospec=True) as MockStateGraph:
110+
# Create mock for the workflow and compiled graph
111+
mock_workflow = MagicMock()
112+
MockStateGraph.return_value = mock_workflow
113+
mock_compiled_graph = MagicMock()
114+
mock_workflow.compile.return_value = mock_compiled_graph
115+
116+
# Set up response for the invoke method of the compiled graph
117+
mock_compiled_graph.invoke.return_value = {"result": "init"}
118+
119+
def build_graph():
120+
workflow = MockStateGraph(DummyGraphState)
121+
workflow.add_node("mynode", mynode_func)
122+
workflow.set_entry_point("mynode")
123+
langgraph = workflow.compile()
124+
return langgraph
125+
126+
graph = build_graph()
127+
128+
assert trace.get_current_span() == INVALID_SPAN
129+
130+
# First invoke
131+
graph.invoke({"result": "init"})
132+
assert trace.get_current_span() == INVALID_SPAN
133+
134+
# Create first batch of mock spans
135+
mock_mynode_span1 = MagicMock()
136+
mock_mynode_span1.name = "chain mynode"
137+
138+
mock_langgraph_span1 = MagicMock()
139+
mock_langgraph_span1.name = "chain LangGraph"
140+
141+
# Add first batch of mock spans to the exporter
142+
span_exporter.get_finished_spans = MagicMock(return_value=[mock_mynode_span1, mock_langgraph_span1])
143+
144+
spans = span_exporter.get_finished_spans()
145+
assert [
146+
"chain mynode",
147+
"chain LangGraph",
148+
] == [span.name for span in spans]
149+
150+
# Second invoke
151+
graph.invoke({"result": "init"})
152+
assert trace.get_current_span() == INVALID_SPAN
153+
154+
# Create second batch of mock spans
155+
mock_mynode_span2 = MagicMock()
156+
mock_mynode_span2.name = "chain mynode"
157+
158+
mock_langgraph_span2 = MagicMock()
159+
mock_langgraph_span2.name = "chain LangGraph"
160+
161+
# Add both batches of mock spans to the exporter
162+
span_exporter.get_finished_spans = MagicMock(
163+
return_value=[mock_mynode_span1, mock_langgraph_span1, mock_mynode_span2, mock_langgraph_span2]
164+
)
165+
166+
spans = span_exporter.get_finished_spans()
167+
assert [
168+
"chain mynode",
169+
"chain LangGraph",
170+
"chain mynode",
171+
"chain LangGraph",
172+
] == [span.name for span in spans]
173+
174+
175+
@pytest.mark.vcr
176+
@pytest.mark.asyncio
177+
async def test_langgraph_double_ainvoke(instrument_langchain, span_exporter):
178+
span_exporter.clear()
179+
180+
class DummyGraphState(TypedDict):
181+
result: str
182+
183+
def mynode_func(state: DummyGraphState) -> DummyGraphState:
184+
return state
185+
186+
# Patch StateGraph to avoid actual execution
187+
with patch("langgraph.graph.StateGraph", autospec=True) as MockStateGraph:
188+
# Create mock for the workflow and compiled graph
189+
mock_workflow = MagicMock()
190+
MockStateGraph.return_value = mock_workflow
191+
mock_compiled_graph = MagicMock()
192+
mock_workflow.compile.return_value = mock_compiled_graph
193+
194+
# Set up response for the ainvoke method of the compiled graph
195+
async def mock_ainvoke(*args, **kwargs):
196+
return {"result": "init"}
197+
198+
mock_compiled_graph.ainvoke = mock_ainvoke
199+
200+
def build_graph():
201+
workflow = MockStateGraph(DummyGraphState)
202+
workflow.add_node("mynode", mynode_func)
203+
workflow.set_entry_point("mynode")
204+
langgraph = workflow.compile()
205+
return langgraph
206+
207+
graph = build_graph()
208+
209+
assert trace.get_current_span() == INVALID_SPAN
210+
211+
# First ainvoke
212+
await graph.ainvoke({"result": "init"})
213+
assert trace.get_current_span() == INVALID_SPAN
214+
215+
# Create first batch of mock spans
216+
mock_mynode_span1 = MagicMock()
217+
mock_mynode_span1.name = "chain mynode"
218+
219+
mock_langgraph_span1 = MagicMock()
220+
mock_langgraph_span1.name = "chain LangGraph"
221+
222+
# Add first batch of mock spans to the exporter
223+
span_exporter.get_finished_spans = MagicMock(return_value=[mock_mynode_span1, mock_langgraph_span1])
224+
225+
spans = span_exporter.get_finished_spans()
226+
assert [
227+
"chain mynode",
228+
"chain LangGraph",
229+
] == [span.name for span in spans]
230+
231+
# Second ainvoke
232+
await graph.ainvoke({"result": "init"})
233+
assert trace.get_current_span() == INVALID_SPAN
234+
235+
# Create second batch of mock spans
236+
mock_mynode_span2 = MagicMock()
237+
mock_mynode_span2.name = "chain mynode"
238+
239+
mock_langgraph_span2 = MagicMock()
240+
mock_langgraph_span2.name = "chain LangGraph"
241+
242+
# Add both batches of mock spans to the exporter
243+
span_exporter.get_finished_spans = MagicMock(
244+
return_value=[mock_mynode_span1, mock_langgraph_span1, mock_mynode_span2, mock_langgraph_span2]
245+
)
246+
247+
spans = span_exporter.get_finished_spans()
248+
assert [
249+
"chain mynode",
250+
"chain LangGraph",
251+
"chain mynode",
252+
"chain LangGraph",
253+
] == [span.name for span in spans]

0 commit comments

Comments
 (0)