|
18 | 18 | TYPE_CHECKING, |
19 | 19 | Any, |
20 | 20 | Callable, |
21 | | - Iterable, |
22 | 21 | MutableSequence, |
23 | | - Optional, |
24 | | - Union, |
25 | 22 | ) |
26 | 23 |
|
27 | 24 | from opentelemetry._events import EventLogger |
|
33 | 30 | from opentelemetry.trace import SpanKind, Tracer |
34 | 31 |
|
35 | 32 | if TYPE_CHECKING: |
| 33 | + from google.cloud.aiplatform_v1.services.prediction_service import client |
36 | 34 | from google.cloud.aiplatform_v1.types import ( |
37 | 35 | content, |
38 | 36 | prediction_service, |
39 | 37 | ) |
40 | | - from vertexai.generative_models import ( |
41 | | - GenerationResponse, |
| 38 | + from google.cloud.aiplatform_v1beta1.services.prediction_service import ( |
| 39 | + client as client_v1beta1, |
42 | 40 | ) |
43 | | - from vertexai.generative_models._generative_models import ( |
44 | | - _GenerativeModel, |
| 41 | + from google.cloud.aiplatform_v1beta1.types import ( |
| 42 | + content as content_v1beta1, |
| 43 | + ) |
| 44 | + from google.cloud.aiplatform_v1beta1.types import ( |
| 45 | + prediction_service as prediction_service_v1beta1, |
45 | 46 | ) |
46 | 47 |
|
47 | 48 |
|
48 | 49 | # Use parameter signature from |
49 | 50 | # https://github.com/googleapis/python-aiplatform/blob/v1.76.0/google/cloud/aiplatform_v1/services/prediction_service/client.py#L2088 |
50 | 51 | # to handle named vs positional args robustly |
51 | 52 | def _extract_params( |
52 | | - request: Optional[ |
53 | | - Union[prediction_service.GenerateContentRequest, dict[Any, Any]] |
54 | | - ] = None, |
| 53 | + request: prediction_service.GenerateContentRequest |
| 54 | + | prediction_service_v1beta1.GenerateContentRequest |
| 55 | + | dict[Any, Any] |
| 56 | + | None = None, |
55 | 57 | *, |
56 | | - model: Optional[str] = None, |
57 | | - contents: Optional[MutableSequence[content.Content]] = None, |
| 58 | + model: str | None = None, |
| 59 | + contents: MutableSequence[content.Content] |
| 60 | + | MutableSequence[content_v1beta1.Content] |
| 61 | + | None = None, |
58 | 62 | **_kwargs: Any, |
59 | 63 | ) -> GenerateContentParams: |
60 | 64 | # Request vs the named parameters are mututally exclusive or the RPC will fail |
@@ -86,9 +90,12 @@ def generate_content_create( |
86 | 90 |
|
87 | 91 | def traced_method( |
88 | 92 | wrapped: Callable[ |
89 | | - ..., GenerationResponse | Iterable[GenerationResponse] |
| 93 | + ..., |
| 94 | + prediction_service.GenerateContentResponse |
| 95 | + | prediction_service_v1beta1.GenerateContentResponse, |
90 | 96 | ], |
91 | | - instance: _GenerativeModel, |
| 97 | + instance: client.PredictionServiceClient |
| 98 | + | client_v1beta1.PredictionServiceClient, |
92 | 99 | args: Any, |
93 | 100 | kwargs: Any, |
94 | 101 | ): |
|
0 commit comments