|
16 | 16 | """Classes for working with vision models."""
|
17 | 17 |
|
18 | 18 | import base64
|
| 19 | +import collections |
19 | 20 | import dataclasses
|
20 | 21 | import hashlib
|
21 | 22 | import io
|
22 | 23 | import json
|
23 | 24 | import pathlib
|
24 | 25 | import typing
|
25 | 26 | from typing import Any, Dict, List, Literal, Optional, Union
|
26 |
| -import urllib |
| 27 | + |
| 28 | +from google.protobuf import struct_pb2 |
| 29 | + |
| 30 | +from proto.marshal.collections import maps |
| 31 | +from proto.marshal.collections import repeated |
| 32 | + |
27 | 33 |
|
28 | 34 | # pylint: disable=g-import-not-at-top
|
29 | 35 | try:
|
|
37 | 43 | PIL_Image = None
|
38 | 44 |
|
39 | 45 |
|
def to_value(value) -> struct_pb2.Value:
    """Coerce a Python value into a ``google.protobuf.Value``.

    Args:
        value: A ``struct_pb2.Value`` proto, ``None``, bool, int, float,
            str, sequence, or mapping. Sequences and mappings are converted
            recursively via ``to_list_value`` / ``to_mapping_value``.

    Returns:
        struct_pb2.Value: A protobuf ``Value`` wrapping ``value``.

    Raises:
        ValueError: If ``value`` cannot be represented as a ``Value``
            (including bytes, which Struct has no representation for).
    """
    if isinstance(value, struct_pb2.Value):
        # Already a proto Value; preserve the instance we were given.
        return value
    if value is None:
        return struct_pb2.Value(null_value=0)
    # bool must be tested before int/float: bool is a subclass of int.
    if isinstance(value, bool):
        return struct_pb2.Value(bool_value=value)
    if isinstance(value, (int, float)):
        return struct_pb2.Value(number_value=float(value))
    # str must be tested before Sequence: str is itself a Sequence.
    if isinstance(value, str):
        return struct_pb2.Value(string_value=value)
    if isinstance(value, (bytes, bytearray)):
        # Struct has no bytes type. Without this guard, bytes would fall
        # into the Sequence branch and be silently coerced into a list of
        # integer byte codes, which is never what callers intend.
        raise ValueError("Unable to coerce bytes value: %r" % value)
    if isinstance(value, collections.abc.Sequence):
        return struct_pb2.Value(list_value=to_list_value(value))
    if isinstance(value, collections.abc.Mapping):
        return struct_pb2.Value(struct_value=to_mapping_value(value))
    raise ValueError("Unable to coerce value: %r" % value)
| 63 | + |
def to_list_value(value) -> struct_pb2.ListValue:
    """Coerce a sequence into a ``google.protobuf.ListValue``.

    Args:
        value: A ``struct_pb2.ListValue`` proto, a proto-plus repeated
            composite, or any list-like sequence of coercible values.

    Returns:
        struct_pb2.ListValue: A protobuf ``ListValue`` wrapping ``value``.
    """
    # Already a proto ListValue: preserve the exact instance we were given.
    if isinstance(value, struct_pb2.ListValue):
        return value
    # proto-plus repeated composite: unwrap the underlying pb values.
    if isinstance(value, repeated.RepeatedComposite):
        return struct_pb2.ListValue(values=list(value.pb))
    # Anything else list-like: coerce each element recursively.
    return struct_pb2.ListValue(values=[to_value(item) for item in value])
| 76 | + |
def to_mapping_value(value) -> struct_pb2.Struct:
    """Coerce a mapping into a ``google.protobuf.Struct``.

    Args:
        value: A ``struct_pb2.Struct`` proto, a proto-plus map composite,
            or any dict-like mapping of coercible values.

    Returns:
        struct_pb2.Struct: A protobuf ``Struct`` wrapping ``value``.
    """
    # Already a proto Struct: preserve the exact instance we were given.
    if isinstance(value, struct_pb2.Struct):
        return value
    # proto-plus map composite: unwrap the underlying pb entries.
    if isinstance(value, maps.MapComposite):
        return struct_pb2.Struct(fields=dict(value.pb.items()))
    # Anything else dict-like: coerce each value recursively.
    converted = {key: to_value(item) for key, item in value.items()}
    return struct_pb2.Struct(fields=converted)
| 93 | + |
| 94 | + |
| 95 | + |
40 | 96 | _SUPPORTED_UPSCALING_SIZES = [2048, 4096]
|
41 | 97 |
|
42 | 98 |
|
@@ -357,7 +413,7 @@ class ID
|
357 | 413 | shared_generation_parameters["person_generation"] = person_generation
|
358 | 414 |
|
359 | 415 | response = self._endpoint.predict(
|
360 |
| - instances=[instance], |
| 416 | + instances=[to_value(instance)], |
361 | 417 | parameters=parameters,
|
362 | 418 | )
|
363 | 419 |
|
@@ -666,7 +722,7 @@ def upscale_image(
|
666 | 722 | ] = output_compression_quality
|
667 | 723 |
|
668 | 724 | response = self._endpoint.predict(
|
669 |
| - instances=[instance], |
| 725 | + instances=[to_value(instance)], |
670 | 726 | parameters=parameters,
|
671 | 727 | )
|
672 | 728 |
|
|
0 commit comments