diff --git a/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/__init__.py b/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/__init__.py index 8fab8c1..05d84d7 100644 --- a/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/__init__.py +++ b/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/__init__.py @@ -20,8 +20,7 @@ from timeit import default_timer from typing import Collection -import openai -from wrapt import wrap_function_wrapper +from wrapt import register_post_import_hook, wrap_function_wrapper from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap @@ -83,6 +82,9 @@ def _instrument(self, **kwargs): self.token_usage_metric = create_gen_ai_client_token_usage(self.meter) self.operation_duration_metric = create_gen_ai_client_operation_duration(self.meter) + register_post_import_hook(self._patch, "openai") + + def _patch(self, _module): wrap_function_wrapper( "openai.resources.chat.completions", "Completions.create", @@ -95,6 +97,10 @@ def _instrument(self, **kwargs): ) def _uninstrument(self, **kwargs): + # unwrap only supports uninstrementing real module references so we + # import here. + import openai + unwrap(openai.resources.chat.completions.Completions, "create") unwrap(openai.resources.chat.completions.AsyncCompletions, "create") diff --git a/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/helpers.py b/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/helpers.py index c0f9005..90b44e3 100644 --- a/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/helpers.py +++ b/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/helpers.py @@ -17,8 +17,8 @@ import json from collections.abc import Iterable from timeit import default_timer +from typing import TYPE_CHECKING -from openai.types import CompletionUsage from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE from opentelemetry.semconv.attributes.server_attributes import SERVER_ADDRESS, SERVER_PORT from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import ( @@ -41,6 +41,11 @@ from opentelemetry.metrics import Histogram from opentelemetry.trace import Span +if TYPE_CHECKING: + from openai.types import CompletionUsage +else: + CompletionUsage = None + def _set_span_attributes_from_response( span: Span, response_id: str, model: str, choices, usage: CompletionUsage