Skip to content

Commit 3512ca3

Browse files
committed
Actually use GAPIC client since that's what we use under the hood
Also this is what LangChain uses
1 parent 66ed1de commit 3512ca3

File tree

6 files changed

+107
-229
lines changed

6 files changed

+107
-229
lines changed

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/__init__.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,15 @@ def _instrument(self, **kwargs: Any):
7878
)
7979

8080
wrap_function_wrapper(
81-
module="vertexai.generative_models._generative_models",
82-
# Patching this base class also instruments the vertexai.preview.generative_models
83-
# package
84-
name="_GenerativeModel.generate_content",
81+
module="google.cloud.aiplatform_v1beta1.services.prediction_service.client",
82+
name="PredictionServiceClient.generate_content",
83+
wrapper=generate_content_create(
84+
tracer, event_logger, is_content_enabled()
85+
),
86+
)
87+
wrap_function_wrapper(
88+
module="google.cloud.aiplatform_v1.services.prediction_service.client",
89+
name="PredictionServiceClient.generate_content",
8590
wrapper=generate_content_create(
8691
tracer, event_logger, is_content_enabled()
8792
),

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/patch.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,15 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional
17+
from typing import (
18+
TYPE_CHECKING,
19+
Any,
20+
Callable,
21+
Iterable,
22+
MutableSequence,
23+
Optional,
24+
Union,
25+
)
1826

1927
from opentelemetry._events import EventLogger
2028
from opentelemetry.instrumentation.vertexai.utils import (
@@ -25,41 +33,49 @@
2533
from opentelemetry.trace import SpanKind, Tracer
2634

2735
if TYPE_CHECKING:
36+
from google.cloud.aiplatform_v1.types import (
37+
content,
38+
prediction_service,
39+
)
2840
from vertexai.generative_models import (
2941
GenerationResponse,
30-
Tool,
31-
ToolConfig,
3242
)
3343
from vertexai.generative_models._generative_models import (
34-
ContentsType,
35-
GenerationConfigType,
36-
SafetySettingsType,
3744
_GenerativeModel,
3845
)
3946

4047

4148
# Use parameter signature from
42-
# https://github.com/googleapis/python-aiplatform/blob/v1.76.0/vertexai/generative_models/_generative_models.py#L595
49+
# https://github.com/googleapis/python-aiplatform/blob/v1.76.0/google/cloud/aiplatform_v1/services/prediction_service/client.py#L2088
4350
# to handle named vs positional args robustly
4451
def _extract_params(
45-
contents: ContentsType,
52+
request: Optional[
53+
Union[prediction_service.GenerateContentRequest, dict[Any, Any]]
54+
] = None,
4655
*,
47-
generation_config: Optional[GenerationConfigType] = None,
48-
safety_settings: Optional[SafetySettingsType] = None,
49-
tools: Optional[list[Tool]] = None,
50-
tool_config: Optional[ToolConfig] = None,
51-
labels: Optional[dict[str, str]] = None,
52-
stream: bool = False,
56+
model: Optional[str] = None,
57+
contents: Optional[MutableSequence[content.Content]] = None,
5358
**_kwargs: Any,
5459
) -> GenerateContentParams:
60+
# The request and the named parameters are mutually exclusive or the RPC will fail
61+
if not request:
62+
return GenerateContentParams(
63+
model=model or "",
64+
contents=contents,
65+
)
66+
67+
if isinstance(request, dict):
68+
return GenerateContentParams(**request)
69+
5570
return GenerateContentParams(
56-
contents=contents,
57-
generation_config=generation_config,
58-
safety_settings=safety_settings,
59-
tools=tools,
60-
tool_config=tool_config,
61-
labels=labels,
62-
stream=stream,
71+
model=request.model,
72+
contents=request.contents,
73+
system_instruction=request.system_instruction,
74+
tools=request.tools,
75+
tool_config=request.tool_config,
76+
labels=request.labels,
77+
safety_settings=request.safety_settings,
78+
generation_config=request.generation_config,
6379
)
6480

6581

@@ -77,7 +93,7 @@ def traced_method(
7793
kwargs: Any,
7894
):
7995
params = _extract_params(*args, **kwargs)
80-
span_attributes = get_genai_request_attributes(instance, params)
96+
span_attributes = get_genai_request_attributes(params)
8197

8298
span_name = get_span_name(span_attributes)
8399
with tracer.start_as_current_span(

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/utils.py

Lines changed: 56 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,14 @@
1414

1515
from __future__ import annotations
1616

17+
import re
1718
from dataclasses import dataclass
1819
from os import environ
1920
from typing import (
2021
TYPE_CHECKING,
21-
Dict,
22-
List,
2322
Mapping,
2423
Optional,
25-
TypedDict,
26-
cast,
24+
Sequence,
2725
)
2826

2927
from opentelemetry.semconv._incubating.attributes import (
@@ -32,96 +30,77 @@
3230
from opentelemetry.util.types import AttributeValue
3331

3432
if TYPE_CHECKING:
35-
from vertexai.generative_models import Tool, ToolConfig
36-
from vertexai.generative_models._generative_models import (
37-
ContentsType,
38-
GenerationConfigType,
39-
SafetySettingsType,
40-
_GenerativeModel,
41-
)
33+
from google.cloud.aiplatform_v1.types import content, tool
4234

4335

4436
@dataclass(frozen=True)
4537
class GenerateContentParams:
46-
contents: ContentsType
47-
generation_config: Optional[GenerationConfigType]
48-
safety_settings: Optional[SafetySettingsType]
49-
tools: Optional[List["Tool"]]
50-
tool_config: Optional["ToolConfig"]
51-
labels: Optional[Dict[str, str]]
52-
stream: bool
53-
54-
55-
class GenerationConfigDict(TypedDict, total=False):
56-
temperature: Optional[float]
57-
top_p: Optional[float]
58-
top_k: Optional[int]
59-
max_output_tokens: Optional[int]
60-
stop_sequences: Optional[List[str]]
61-
presence_penalty: Optional[float]
62-
frequency_penalty: Optional[float]
63-
seed: Optional[int]
64-
# And more fields which aren't needed yet
38+
model: str
39+
contents: Optional[Sequence[content.Content]] = None
40+
system_instruction: Optional[content.Content | None] = None
41+
tools: Optional[Sequence[tool.Tool]] = None
42+
tool_config: Optional[tool.ToolConfig] = None
43+
labels: Optional[Mapping[str, str]] = None
44+
safety_settings: Optional[Sequence[content.SafetySetting]] = None
45+
generation_config: Optional[content.GenerationConfig] = None
6546

6647

6748
def get_genai_request_attributes(
68-
instance: _GenerativeModel,
6949
params: GenerateContentParams,
7050
operation_name: GenAIAttributes.GenAiOperationNameValues = GenAIAttributes.GenAiOperationNameValues.CHAT,
7151
):
72-
model = _get_model_name(instance)
73-
generation_config = _get_generation_config(instance, params)
74-
attributes = {
52+
model = _get_model_name(params.model)
53+
generation_config = params.generation_config
54+
attributes: dict[str, AttributeValue] = {
7555
GenAIAttributes.GEN_AI_OPERATION_NAME: operation_name.value,
7656
GenAIAttributes.GEN_AI_SYSTEM: GenAIAttributes.GenAiSystemValues.VERTEX_AI.value,
7757
GenAIAttributes.GEN_AI_REQUEST_MODEL: model,
78-
GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE: generation_config.get(
79-
"temperature"
80-
),
81-
GenAIAttributes.GEN_AI_REQUEST_TOP_P: generation_config.get("top_p"),
82-
GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS: generation_config.get(
83-
"max_output_tokens"
84-
),
85-
GenAIAttributes.GEN_AI_REQUEST_PRESENCE_PENALTY: generation_config.get(
86-
"presence_penalty"
87-
),
88-
GenAIAttributes.GEN_AI_REQUEST_FREQUENCY_PENALTY: generation_config.get(
89-
"frequency_penalty"
90-
),
91-
GenAIAttributes.GEN_AI_OPENAI_REQUEST_SEED: generation_config.get(
92-
"seed"
93-
),
94-
GenAIAttributes.GEN_AI_REQUEST_STOP_SEQUENCES: generation_config.get(
95-
"stop_sequences"
96-
),
9758
}
9859

99-
# filter out None values
100-
return {k: v for k, v in attributes.items() if v is not None}
101-
102-
103-
def _get_generation_config(
104-
instance: _GenerativeModel,
105-
params: GenerateContentParams,
106-
) -> GenerationConfigDict:
107-
generation_config = params.generation_config or instance._generation_config
108-
if generation_config is None:
109-
return {}
110-
if isinstance(generation_config, dict):
111-
return cast(GenerationConfigDict, generation_config)
112-
return cast(GenerationConfigDict, generation_config.to_dict())
113-
114-
115-
_RESOURCE_PREFIX = "publishers/google/models/"
116-
60+
if not generation_config:
61+
return attributes
62+
63+
# Check for optional fields
64+
# https://proto-plus-python.readthedocs.io/en/stable/fields.html#optional-fields
65+
if "temperature" in generation_config:
66+
attributes[GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE] = (
67+
generation_config.temperature
68+
)
69+
if "top_p" in generation_config:
70+
attributes[GenAIAttributes.GEN_AI_REQUEST_TOP_P] = (
71+
generation_config.top_p
72+
)
73+
if "max_output_tokens" in generation_config:
74+
attributes[GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS] = (
75+
generation_config.max_output_tokens
76+
)
77+
if "presence_penalty" in generation_config:
78+
attributes[GenAIAttributes.GEN_AI_REQUEST_PRESENCE_PENALTY] = (
79+
generation_config.presence_penalty
80+
)
81+
if "frequency_penalty" in generation_config:
82+
attributes[GenAIAttributes.GEN_AI_REQUEST_FREQUENCY_PENALTY] = (
83+
generation_config.frequency_penalty
84+
)
85+
if "seed" in generation_config:
86+
attributes[GenAIAttributes.GEN_AI_OPENAI_REQUEST_SEED] = (
87+
generation_config.seed
88+
)
89+
if "stop_sequences" in generation_config:
90+
attributes[GenAIAttributes.GEN_AI_REQUEST_STOP_SEQUENCES] = (
91+
generation_config.stop_sequences
92+
)
93+
94+
return attributes
95+
96+
97+
_MODEL_STRIP_RE = re.compile(
98+
r"^projects/(.*)/locations/(.*)/publishers/google/models/"
99+
)
117100

118-
def _get_model_name(instance: _GenerativeModel) -> str:
119-
model_name = instance._model_name
120101

121-
# Can use str.removeprefix() once 3.8 is dropped
122-
if model_name.startswith(_RESOURCE_PREFIX):
123-
model_name = model_name[len(_RESOURCE_PREFIX) :]
124-
return model_name
102+
def _get_model_name(model: str) -> str:
103+
return _MODEL_STRIP_RE.sub("", model)
125104

126105

127106
OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = (

instrumentation-genai/opentelemetry-instrumentation-vertexai/tests/cassettes/test_chat_completion_extra_client_level_params.yaml

Lines changed: 0 additions & 82 deletions
This file was deleted.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ interactions:
5555
]
5656
},
5757
"finishReason": 2,
58-
"avgLogprobs": -0.006723951548337936
58+
"avgLogprobs": -0.006721805781126022
5959
}
6060
],
6161
"usageMetadata": {

0 commit comments

Comments
 (0)