|
3 | 3 | import os
|
4 | 4 | import contextlib
|
5 | 5 | import inspect
|
| 6 | +import collections |
6 | 7 | import dataclasses
|
7 | 8 | import pathlib
|
8 | 9 | from typing import Any, cast
|
9 |
| -from collections.abc import Sequence |
| 10 | +from collections.abc import Sequence, Mapping |
10 | 11 | import httplib2
|
11 | 12 | from io import IOBase
|
12 | 13 |
|
|
23 | 24 | import googleapiclient.http
|
24 | 25 | import googleapiclient.discovery
|
25 | 26 |
|
| 27 | +from google.protobuf import struct_pb2 |
| 28 | + |
| 29 | +from proto.marshal.collections import maps |
| 30 | +from proto.marshal.collections import repeated |
| 31 | + |
26 | 32 | try:
|
27 | 33 | from google.generativeai import version
|
28 | 34 |
|
@@ -130,6 +136,70 @@ async def create_file(self, *args, **kwargs):
|
130 | 136 | )
|
131 | 137 |
|
132 | 138 |
|
| 139 | +# This is to get around https://github.com/googleapis/proto-plus-python/issues/488 |
| 140 | +def to_value(value) -> struct_pb2.Value: |
| 141 | + """Return a protobuf Value object representing this value.""" |
| 142 | + if isinstance(value, struct_pb2.Value): |
| 143 | + return value |
| 144 | + if value is None: |
| 145 | + return struct_pb2.Value(null_value=0) |
| 146 | + if isinstance(value, bool): |
| 147 | + return struct_pb2.Value(bool_value=value) |
| 148 | + if isinstance(value, (int, float)): |
| 149 | + return struct_pb2.Value(number_value=float(value)) |
| 150 | + if isinstance(value, str): |
| 151 | + return struct_pb2.Value(string_value=value) |
| 152 | + if isinstance(value, collections.abc.Sequence): |
| 153 | + return struct_pb2.Value(list_value=to_list_value(value)) |
| 154 | + if isinstance(value, collections.abc.Mapping): |
| 155 | + return struct_pb2.Value(struct_value=to_mapping_value(value)) |
| 156 | + raise ValueError("Unable to coerce value: %r" % value) |
| 157 | + |
| 158 | + |
| 159 | +def to_list_value(value) -> struct_pb2.ListValue: |
| 160 | + # We got a proto, or else something we sent originally. |
| 161 | + # Preserve the instance we have. |
| 162 | + if isinstance(value, struct_pb2.ListValue): |
| 163 | + return value |
| 164 | + if isinstance(value, repeated.RepeatedComposite): |
| 165 | + return struct_pb2.ListValue(values=[v for v in value.pb]) |
| 166 | + |
| 167 | + # We got a list (or something list-like); convert it. |
| 168 | + return struct_pb2.ListValue(values=[to_value(v) for v in value]) |
| 169 | + |
| 170 | + |
| 171 | +def to_mapping_value(value) -> struct_pb2.Struct: |
| 172 | + # We got a proto, or else something we sent originally. |
| 173 | + # Preserve the instance we have. |
| 174 | + if isinstance(value, struct_pb2.Struct): |
| 175 | + return value |
| 176 | + if isinstance(value, maps.MapComposite): |
| 177 | + return struct_pb2.Struct( |
| 178 | + fields={k: v for k, v in value.pb.items()}, |
| 179 | + ) |
| 180 | + |
| 181 | + # We got a dict (or something dict-like); convert it. |
| 182 | + return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()}) |
| 183 | + |
| 184 | + |
| 185 | +class PredictionServiceClient(glm.PredictionServiceClient): |
| 186 | + def predict(self, model=None, instances=None, parameters=None): |
| 187 | + pr = protos.PredictRequest.pb() |
| 188 | + request = pr( |
| 189 | + model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters) |
| 190 | + ) |
| 191 | + return super().predict(request) |
| 192 | + |
| 193 | + |
| 194 | +class PredictionServiceAsyncClient(glm.PredictionServiceAsyncClient): |
| 195 | + async def predict(self, model=None, instances=None, parameters=None): |
| 196 | + pr = protos.PredictRequest.pb() |
| 197 | + request = pr( |
| 198 | + model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters) |
| 199 | + ) |
| 200 | + return await super().predict(request) |
| 201 | + |
| 202 | + |
133 | 203 | @dataclasses.dataclass
|
134 | 204 | class _ClientManager:
|
135 | 205 | client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
|
@@ -220,15 +290,20 @@ def configure(
|
220 | 290 | self.clients = {}
|
221 | 291 |
|
222 | 292 | def make_client(self, name):
|
223 |
| - if name == "file": |
224 |
| - cls = FileServiceClient |
225 |
| - elif name == "file_async": |
226 |
| - cls = FileServiceAsyncClient |
227 |
| - elif name.endswith("_async"): |
228 |
| - name = name.split("_")[0] |
229 |
| - cls = getattr(glm, name.title() + "ServiceAsyncClient") |
230 |
| - else: |
231 |
| - cls = getattr(glm, name.title() + "ServiceClient") |
| 293 | + local_clients = { |
| 294 | + "file": FileServiceClient, |
| 295 | + "file_async": FileServiceAsyncClient, |
| 296 | + "prediction": PredictionServiceClient, |
| 297 | + "prediction_async": PredictionServiceAsyncClient, |
| 298 | + } |
| 299 | + cls = local_clients.get(name, None) |
| 300 | + |
| 301 | + if cls is None: |
| 302 | + if name.endswith("_async"): |
| 303 | + name = name.split("_")[0] |
| 304 | + cls = getattr(glm, name.title() + "ServiceAsyncClient") |
| 305 | + else: |
| 306 | + cls = getattr(glm, name.title() + "ServiceClient") |
232 | 307 |
|
233 | 308 | # Attempt to configure using defaults.
|
234 | 309 | if not self.client_config:
|
|
0 commit comments