Skip to content

Commit a7b4add

Browse files
committed
Working OK patching the generate_content() method
1 parent 8fa6dda commit a7b4add

File tree

4 files changed

+250
-4
lines changed

4 files changed

+250
-4
lines changed

instrumentation-genai/opentelemetry-instrumentation-vertexai/examples/zero-code/main.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import vertexai
2-
from vertexai.generative_models import GenerativeModel
2+
from vertexai.generative_models import GenerationConfig, GenerativeModel
33

44

55
def main():
    """Send one prompt to Gemini and print the reply (zero-code example)."""
    vertexai.init()
    model = GenerativeModel("gemini-1.5-flash-002")
    # Explicit generation options so the instrumentation has request
    # attributes (top_k, top_p, temperature, stop sequences) to record.
    config = GenerationConfig(
        top_k=2, top_p=0.95, temperature=0.2, stop_sequences=["\n\n\n"]
    )
    response = model.generate_content(
        "Write a short poem on OpenTelemetry.",
        generation_config=config,
    )
    print(response.text)
1215

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/__init__.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,17 @@
4141

4242
from typing import Any, Collection
4343

44+
from wrapt import (
45+
wrap_function_wrapper, # type: ignore[reportUnknownVariableType]
46+
)
47+
4448
from opentelemetry._events import get_event_logger
49+
from opentelemetry.instrumentation.genai_utils import is_content_enabled
4550
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
4651
from opentelemetry.instrumentation.vertexai.package import _instruments
52+
from opentelemetry.instrumentation.vertexai.patch import (
53+
generate_content_create,
54+
)
4755
from opentelemetry.semconv.schemas import Schemas
4856
from opentelemetry.trace import get_tracer
4957

@@ -55,20 +63,28 @@ def instrumentation_dependencies(self) -> Collection[str]:
5563
def _instrument(self, **kwargs: Any):
    """Enable VertexAI instrumentation.

    Patches `_GenerativeModel.generate_content` with a tracing wrapper
    built from the configured tracer and event logger.
    """
    event_logger = get_event_logger(
        __name__,
        "",
        schema_url=Schemas.V1_28_0.value,
        event_logger_provider=kwargs.get("event_logger_provider"),
    )
    tracer = get_tracer(
        __name__,
        "",
        kwargs.get("tracer_provider"),
        schema_url=Schemas.V1_28_0.value,
    )
    wrap_function_wrapper(
        module="vertexai.generative_models._generative_models",
        name="_GenerativeModel.generate_content",
        wrapper=generate_content_create(
            tracer, event_logger, is_content_enabled()
        ),
    )
7389
def _uninstrument(self, **kwargs: Any) -> None:
    """Remove VertexAI instrumentation (TODO: implemented in later PR)."""

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/patch.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,102 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import (
18+
TYPE_CHECKING,
19+
Any,
20+
Callable,
21+
Optional,
22+
)
23+
24+
from opentelemetry._events import EventLogger
25+
from opentelemetry.instrumentation.genai_utils import (
26+
get_span_name, # type: ignore[reportUnknownVariableType]
27+
handle_span_exception, # type: ignore[reportUnknownVariableType]
28+
)
29+
from opentelemetry.instrumentation.vertexai.utils import (
30+
GenerateContentParams,
31+
get_genai_request_attributes,
32+
)
33+
from opentelemetry.trace import SpanKind, Tracer
34+
35+
if TYPE_CHECKING:
36+
from vertexai.generative_models import Tool, ToolConfig
37+
from vertexai.generative_models._generative_models import (
38+
ContentsType,
39+
GenerationConfigType,
40+
SafetySettingsType,
41+
_GenerativeModel,
42+
)
43+
44+
45+
def generate_content_create(
    tracer: Tracer, event_logger: EventLogger, capture_content: bool
):
    """Wrap the `generate_content` method of the `GenerativeModel` class to trace it.

    Args:
        tracer: used to open one CLIENT span per call.
        event_logger: prompt/response event emission (wired up in a later PR).
        capture_content: whether message content may be recorded (later PR).
    """

    def traced_method(
        wrapped: Callable[..., Any],
        instance: _GenerativeModel,
        args: Any,
        kwargs: Any,
    ):
        # Mirror generate_content's exact signature so positional and
        # keyword arguments are normalized the same way the real API does.
        def extract_params(
            contents: ContentsType,
            *,
            generation_config: Optional[GenerationConfigType] = None,
            safety_settings: Optional[SafetySettingsType] = None,
            tools: Optional[list[Tool]] = None,
            tool_config: Optional[ToolConfig] = None,
            labels: Optional[dict[str, str]] = None,
            stream: bool = False,
        ) -> GenerateContentParams:
            return GenerateContentParams(
                contents=contents,
                generation_config=generation_config,
                safety_settings=safety_settings,
                tools=tools,
                tool_config=tool_config,
                labels=labels,
                stream=stream,
            )

        params = extract_params(*args, **kwargs)
        attributes = get_genai_request_attributes(instance, params)

        with tracer.start_as_current_span(
            name=get_span_name(attributes),
            kind=SpanKind.CLIENT,
            attributes=attributes,
            # Ended manually so a future streaming wrapper can keep the span
            # open until the stream is fully consumed.
            end_on_exit=False,
        ) as span:
            # TODO(later PR): emit prompt message events, set response
            # attributes, and wrap streaming results before ending the span.
            try:
                result = wrapped(*args, **kwargs)
                span.end()
                return result
            except Exception as error:
                handle_span_exception(span, error)
                raise

    return traced_method
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from dataclasses import dataclass
18+
from typing import (
19+
TYPE_CHECKING,
20+
Dict,
21+
List,
22+
Optional,
23+
TypedDict,
24+
cast,
25+
)
26+
27+
if TYPE_CHECKING:
28+
from vertexai.generative_models import Tool, ToolConfig
29+
from vertexai.generative_models._generative_models import (
30+
ContentsType,
31+
GenerationConfigType,
32+
SafetySettingsType,
33+
_GenerativeModel,
34+
)
35+
from opentelemetry.semconv._incubating.attributes import (
36+
gen_ai_attributes as GenAIAttributes,
37+
)
38+
39+
40+
@dataclass(frozen=True)
class GenerateContentParams:
    """Normalized arguments of a `_GenerativeModel.generate_content` call.

    Defaults mirror the `generate_content` signature, so callers only need
    to supply the arguments that were actually passed; previously every
    field was required, forcing all seven to be spelled out.
    """

    contents: ContentsType
    generation_config: Optional[GenerationConfigType] = None
    safety_settings: Optional[SafetySettingsType] = None
    tools: Optional[List[Tool]] = None
    tool_config: Optional[ToolConfig] = None
    labels: Optional[Dict[str, str]] = None
    stream: bool = False
49+
50+
51+
class GenerationConfigDict(TypedDict, total=False):
    """Subset of `GenerationConfig.to_dict()` keys read for span attributes."""

    temperature: Optional[float]
    top_p: Optional[float]
    top_k: Optional[int]
    max_output_tokens: Optional[int]
    stop_sequences: Optional[List[str]]
    presence_penalty: Optional[float]
    frequency_penalty: Optional[float]
    seed: Optional[int]
    # GenerationConfig has more fields; they aren't needed yet.
61+
62+
63+
def get_genai_request_attributes(
    instance: _GenerativeModel,
    params: GenerateContentParams,
    operation_name: GenAIAttributes.GenAiOperationNameValues = GenAIAttributes.GenAiOperationNameValues.CHAT,
):
    """Return the gen_ai.* request attributes for one generate_content call.

    Only attributes whose value is not None are included in the result.
    """
    # TODO: per-call generation_config is coalesced with
    # instance._generation_config in _get_generation_config, but other
    # constructor-level options (e.g. tools, safety settings) are not covered;
    # consider wrapping the underlying PredictionClient instead of duplicating
    # the SDK's coalescing logic.
    config = _get_generation_config(instance, params)

    attributes = {
        GenAIAttributes.GEN_AI_OPERATION_NAME: operation_name.value,
        GenAIAttributes.GEN_AI_SYSTEM: GenAIAttributes.GenAiSystemValues.VERTEX_AI.value,
        GenAIAttributes.GEN_AI_REQUEST_MODEL: _get_model_name(instance),
        GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE: config.get("temperature"),
        GenAIAttributes.GEN_AI_REQUEST_TOP_P: config.get("top_p"),
        GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS: config.get(
            "max_output_tokens"
        ),
        GenAIAttributes.GEN_AI_REQUEST_PRESENCE_PENALTY: config.get(
            "presence_penalty"
        ),
        GenAIAttributes.GEN_AI_REQUEST_FREQUENCY_PENALTY: config.get(
            "frequency_penalty"
        ),
        GenAIAttributes.GEN_AI_OPENAI_REQUEST_SEED: config.get("seed"),
        GenAIAttributes.GEN_AI_REQUEST_STOP_SEQUENCES: config.get(
            "stop_sequences"
        ),
    }

    # Drop attributes that were not set on the request.
    return {key: value for key, value in attributes.items() if value is not None}
102+
103+
104+
def _get_generation_config(
    instance: _GenerativeModel,
    params: GenerateContentParams,
) -> GenerationConfigDict:
    """Coalesce per-call and constructor generation config into a plain dict."""
    # The per-call config wins over the one the model was constructed with.
    config = params.generation_config or instance._generation_config
    if config is None:
        return {}
    # The SDK accepts either a plain dict or a GenerationConfig object here.
    if isinstance(config, dict):
        return cast(GenerationConfigDict, config)
    return cast(GenerationConfigDict, config.to_dict())
114+
115+
116+
_RESOURCE_PREFIX = "publishers/google/models/"
117+
118+
119+
def _get_model_name(instance: _GenerativeModel) -> str:
120+
try:
121+
model_name = instance._model_name
122+
except AttributeError:
123+
model_name = "unknown"
124+
125+
# Can use str.removeprefix() once 3.8 is dropped
126+
if model_name.startswith(_RESOURCE_PREFIX):
127+
model_name = model_name[len(_RESOURCE_PREFIX) :]
128+
return model_name

0 commit comments

Comments
 (0)