Skip to content

Commit abaa373

Browse files
committed
Use a shared function to reduce duplicate code
1 parent 3dd15c0 commit abaa373

File tree

6 files changed

+255
-416
lines changed

6 files changed

+255
-416
lines changed

instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
from timeit import default_timer
17-
from typing import Optional
17+
from typing import Any, Optional
1818

1919
from openai import Stream
2020

@@ -164,27 +164,24 @@ def embeddings_create(
164164
"""Wrap the `create` method of the `Embeddings` class to trace it."""
165165

166166
def traced_method(wrapped, instance, args, kwargs):
167-
span_attributes = {
168-
**get_llm_request_attributes(
169-
kwargs,
170-
instance,
171-
GenAIAttributes.GenAiOperationNameValues.EMBEDDINGS.value,
172-
)
173-
}
167+
span_attributes = get_llm_request_attributes(
168+
kwargs,
169+
instance,
170+
GenAIAttributes.GenAiOperationNameValues.EMBEDDINGS.value,
171+
)
172+
span_name = _get_embeddings_span_name(span_attributes)
173+
input_text = kwargs.get("input", "")
174174

175-
span_name = f"{span_attributes[GenAIAttributes.GEN_AI_OPERATION_NAME]} {span_attributes[GenAIAttributes.GEN_AI_REQUEST_MODEL]}"
176175
with tracer.start_as_current_span(
177176
name=span_name,
178177
kind=SpanKind.CLIENT,
179178
attributes=span_attributes,
180179
end_on_exit=True,
181180
) as span:
182-
# Store the input for later use in the response attributes
183-
input_text = kwargs.get("input", "")
184-
185181
start = default_timer()
186182
result = None
187183
error_type = None
184+
188185
try:
189186
result = wrapped(*args, **kwargs)
190187

@@ -199,6 +196,7 @@ def traced_method(wrapped, instance, args, kwargs):
199196
error_type = type(error).__qualname__
200197
handle_span_exception(span, error)
201198
raise
199+
202200
finally:
203201
duration = max((default_timer() - start), 0)
204202
_record_metrics(
@@ -221,27 +219,24 @@ def async_embeddings_create(
221219
"""Wrap the `create` method of the `AsyncEmbeddings` class to trace it."""
222220

223221
async def traced_method(wrapped, instance, args, kwargs):
224-
span_attributes = {
225-
**get_llm_request_attributes(
226-
kwargs,
227-
instance,
228-
GenAIAttributes.GenAiOperationNameValues.EMBEDDINGS.value,
229-
)
230-
}
222+
span_attributes = get_llm_request_attributes(
223+
kwargs,
224+
instance,
225+
GenAIAttributes.GenAiOperationNameValues.EMBEDDINGS.value,
226+
)
227+
span_name = _get_embeddings_span_name(span_attributes)
228+
input_text = kwargs.get("input", "")
231229

232-
span_name = f"{span_attributes[GenAIAttributes.GEN_AI_OPERATION_NAME]} {span_attributes[GenAIAttributes.GEN_AI_REQUEST_MODEL]}"
233230
with tracer.start_as_current_span(
234231
name=span_name,
235232
kind=SpanKind.CLIENT,
236233
attributes=span_attributes,
237234
end_on_exit=True,
238235
) as span:
239-
# Store the input for later use in the response attributes
240-
input_text = kwargs.get("input", "")
241-
242236
start = default_timer()
243237
result = None
244238
error_type = None
239+
245240
try:
246241
result = await wrapped(*args, **kwargs)
247242

@@ -256,6 +251,7 @@ async def traced_method(wrapped, instance, args, kwargs):
256251
error_type = type(error).__qualname__
257252
handle_span_exception(span, error)
258253
raise
254+
259255
finally:
260256
duration = max((default_timer() - start), 0)
261257
_record_metrics(
@@ -270,6 +266,11 @@ async def traced_method(wrapped, instance, args, kwargs):
270266
return traced_method
271267

272268

269+
def _get_embeddings_span_name(span_attributes):
270+
"""Get span name for embeddings operations."""
271+
return f"{span_attributes[GenAIAttributes.GEN_AI_OPERATION_NAME]} {span_attributes[GenAIAttributes.GEN_AI_REQUEST_MODEL]}"
272+
273+
273274
def _record_metrics(
274275
instruments: Instruments,
275276
duration: float,
@@ -390,13 +391,11 @@ def _set_response_attributes(
390391

391392

392393
def _set_embeddings_response_attributes(
393-
span,
394-
result,
394+
span: Span,
395+
result: Any,
395396
capture_content: bool,
396397
input_text: str,
397398
):
398-
"""Set attributes on the span based on the embeddings response."""
399-
# Set the model name if available
400399
set_span_attribute(
401400
span, GenAIAttributes.GEN_AI_RESPONSE_MODEL, result.model
402401
)

instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_async_chat_completions.py

Lines changed: 54 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,10 @@
1313
# limitations under the License.
1414
# pylint: disable=too-many-locals
1515

16-
from typing import Optional
1716

1817
import pytest
1918
from openai import APIConnectionError, AsyncOpenAI, NotFoundError
20-
from openai.resources.chat.completions import ChatCompletion
2119

22-
from opentelemetry.sdk.trace import ReadableSpan
2320
from opentelemetry.semconv._incubating.attributes import (
2421
error_attributes as ErrorAttributes,
2522
)
@@ -33,6 +30,12 @@
3330
server_attributes as ServerAttributes,
3431
)
3532

33+
from .test_utils import (
34+
assert_all_attributes,
35+
assert_log_parent,
36+
remove_none_values,
37+
)
38+
3639

3740
@pytest.mark.vcr()
3841
@pytest.mark.asyncio()
@@ -47,7 +50,14 @@ async def test_async_chat_completion_with_content(
4750
)
4851

4952
spans = span_exporter.get_finished_spans()
50-
assert_completion_attributes(spans[0], llm_model_value, response)
53+
assert_all_attributes(
54+
spans[0],
55+
llm_model_value,
56+
response.id,
57+
response.model,
58+
response.usage.prompt_tokens,
59+
response.usage.completion_tokens,
60+
)
5161

5262
logs = log_exporter.get_finished_logs()
5363
assert len(logs) == 2
@@ -81,7 +91,14 @@ async def test_async_chat_completion_no_content(
8191
)
8292

8393
spans = span_exporter.get_finished_spans()
84-
assert_completion_attributes(spans[0], llm_model_value, response)
94+
assert_all_attributes(
95+
spans[0],
96+
llm_model_value,
97+
response.id,
98+
response.model,
99+
response.usage.prompt_tokens,
100+
response.usage.completion_tokens,
101+
)
85102

86103
logs = log_exporter.get_finished_logs()
87104
assert len(logs) == 2
@@ -162,7 +179,14 @@ async def test_async_chat_completion_extra_params(
162179
)
163180

164181
spans = span_exporter.get_finished_spans()
165-
assert_completion_attributes(spans[0], llm_model_value, response)
182+
assert_all_attributes(
183+
spans[0],
184+
llm_model_value,
185+
response.id,
186+
response.model,
187+
response.usage.prompt_tokens,
188+
response.usage.completion_tokens,
189+
)
166190
assert (
167191
spans[0].attributes[GenAIAttributes.GEN_AI_OPENAI_REQUEST_SEED] == 42
168192
)
@@ -195,7 +219,14 @@ async def test_async_chat_completion_multiple_choices(
195219
)
196220

197221
spans = span_exporter.get_finished_spans()
198-
assert_completion_attributes(spans[0], llm_model_value, response)
222+
assert_all_attributes(
223+
spans[0],
224+
llm_model_value,
225+
response.id,
226+
response.model,
227+
response.usage.prompt_tokens,
228+
response.usage.completion_tokens,
229+
)
199230

200231
logs = log_exporter.get_finished_logs()
201232
assert len(logs) == 3 # 1 user message + 2 choice messages
@@ -302,8 +333,22 @@ async def chat_completion_tool_call(
302333
# validate both calls
303334
spans = span_exporter.get_finished_spans()
304335
assert len(spans) == 2
305-
assert_completion_attributes(spans[0], llm_model_value, response_0)
306-
assert_completion_attributes(spans[1], llm_model_value, response_1)
336+
assert_all_attributes(
337+
spans[0],
338+
llm_model_value,
339+
response_0.id,
340+
response_0.model,
341+
response_0.usage.prompt_tokens,
342+
response_0.usage.completion_tokens,
343+
)
344+
assert_all_attributes(
345+
spans[1],
346+
llm_model_value,
347+
response_1.id,
348+
response_1.model,
349+
response_1.usage.prompt_tokens,
350+
response_1.usage.completion_tokens,
351+
)
307352

308353
logs = log_exporter.get_finished_logs()
309354
assert len(logs) == 9 # 3 logs for first completion, 6 for second
@@ -813,106 +858,6 @@ def assert_message_in_logs(log, event_name, expected_content, parent_span):
813858
assert_log_parent(log, parent_span)
814859

815860

816-
def remove_none_values(body):
817-
result = {}
818-
for key, value in body.items():
819-
if value is None:
820-
continue
821-
if isinstance(value, dict):
822-
result[key] = remove_none_values(value)
823-
elif isinstance(value, list):
824-
result[key] = [remove_none_values(i) for i in value]
825-
else:
826-
result[key] = value
827-
return result
828-
829-
830-
def assert_completion_attributes(
831-
span: ReadableSpan,
832-
request_model: str,
833-
response: ChatCompletion,
834-
operation_name: str = "chat",
835-
server_address: str = "api.openai.com",
836-
):
837-
return assert_all_attributes(
838-
span,
839-
request_model,
840-
response.id,
841-
response.model,
842-
response.usage.prompt_tokens,
843-
response.usage.completion_tokens,
844-
operation_name,
845-
server_address,
846-
)
847-
848-
849-
def assert_all_attributes(
850-
span: ReadableSpan,
851-
request_model: str,
852-
response_id: str = None,
853-
response_model: str = None,
854-
input_tokens: Optional[int] = None,
855-
output_tokens: Optional[int] = None,
856-
operation_name: str = "chat",
857-
server_address: str = "api.openai.com",
858-
):
859-
assert span.name == f"{operation_name} {request_model}"
860-
assert (
861-
operation_name
862-
== span.attributes[GenAIAttributes.GEN_AI_OPERATION_NAME]
863-
)
864-
assert (
865-
GenAIAttributes.GenAiSystemValues.OPENAI.value
866-
== span.attributes[GenAIAttributes.GEN_AI_SYSTEM]
867-
)
868-
assert (
869-
request_model == span.attributes[GenAIAttributes.GEN_AI_REQUEST_MODEL]
870-
)
871-
if response_model:
872-
assert (
873-
response_model
874-
== span.attributes[GenAIAttributes.GEN_AI_RESPONSE_MODEL]
875-
)
876-
else:
877-
assert GenAIAttributes.GEN_AI_RESPONSE_MODEL not in span.attributes
878-
879-
if response_id:
880-
assert (
881-
response_id == span.attributes[GenAIAttributes.GEN_AI_RESPONSE_ID]
882-
)
883-
else:
884-
assert GenAIAttributes.GEN_AI_RESPONSE_ID not in span.attributes
885-
886-
if input_tokens:
887-
assert (
888-
input_tokens
889-
== span.attributes[GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS]
890-
)
891-
else:
892-
assert GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS not in span.attributes
893-
894-
if output_tokens:
895-
assert (
896-
output_tokens
897-
== span.attributes[GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS]
898-
)
899-
else:
900-
assert (
901-
GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS not in span.attributes
902-
)
903-
904-
assert server_address == span.attributes[ServerAttributes.SERVER_ADDRESS]
905-
906-
907-
def assert_log_parent(log, span):
908-
if span:
909-
assert log.log_record.trace_id == span.get_span_context().trace_id
910-
assert log.log_record.span_id == span.get_span_context().span_id
911-
assert (
912-
log.log_record.trace_flags == span.get_span_context().trace_flags
913-
)
914-
915-
916861
def get_current_weather_tool_definition():
917862
return {
918863
"type": "function",

0 commit comments

Comments (0)