Skip to content

Commit 805a0f4

Browse files
committed
handle instances converversion to Value protos.
Change-Id: Id33f8d2d6a4cffbfb7b0d37955cc800a867a70d5
1 parent a9fa41a commit 805a0f4

File tree

1 file changed

+59
-3
lines changed

1 file changed

+59
-3
lines changed

google/generativeai/vision_models/_vision_models.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,20 @@
1616
"""Classes for working with vision models."""
1717

1818
import base64
19+
import collections
1920
import dataclasses
2021
import hashlib
2122
import io
2223
import json
2324
import pathlib
2425
import typing
2526
from typing import Any, Dict, List, Literal, Optional, Union
26-
import urllib
27+
28+
from google.protobuf import struct_pb2
29+
30+
from proto.marshal.collections import maps
31+
from proto.marshal.collections import repeated
32+
2733

2834
# pylint: disable=g-import-not-at-top
2935
try:
@@ -37,6 +43,56 @@
3743
PIL_Image = None
3844

3945

46+
def to_value(value) -> struct_pb2.Value:
47+
"""Return a protobuf Value object representing this value."""
48+
if isinstance(value, struct_pb2.Value):
49+
return value
50+
if value is None:
51+
return struct_pb2.Value(null_value=0)
52+
if isinstance(value, bool):
53+
return struct_pb2.Value(bool_value=value)
54+
if isinstance(value, (int, float)):
55+
return struct_pb2.Value(number_value=float(value))
56+
if isinstance(value, str):
57+
return struct_pb2.Value(string_value=value)
58+
if isinstance(value, collections.abc.Sequence):
59+
return struct_pb2.Value(list_value=to_list_value(value))
60+
if isinstance(value, collections.abc.Mapping):
61+
return struct_pb2.Value(struct_value=to_mapping_value(value))
62+
raise ValueError("Unable to coerce value: %r" % value)
63+
64+
def to_list_value(value) -> struct_pb2.ListValue:
65+
# We got a proto, or else something we sent originally.
66+
# Preserve the instance we have.
67+
if isinstance(value, struct_pb2.ListValue):
68+
return value
69+
if isinstance(value, repeated.RepeatedComposite):
70+
return struct_pb2.ListValue(values=[v for v in value.pb])
71+
72+
# We got a list (or something list-like); convert it.
73+
return struct_pb2.ListValue(
74+
values=[to_value(v) for v in value]
75+
)
76+
77+
def to_mapping_value(value) -> struct_pb2.Struct:
78+
# We got a proto, or else something we sent originally.
79+
# Preserve the instance we have.
80+
if isinstance(value, struct_pb2.Struct):
81+
return value
82+
if isinstance(value, maps.MapComposite):
83+
return struct_pb2.Struct(
84+
fields={k: v for k, v in value.pb.items()},
85+
)
86+
87+
# We got a dict (or something dict-like); convert it.
88+
return struct_pb2.Struct(
89+
fields={
90+
k: to_value(v) for k, v in value.items()
91+
}
92+
)
93+
94+
95+
4096
_SUPPORTED_UPSCALING_SIZES = [2048, 4096]
4197

4298

@@ -357,7 +413,7 @@ class ID
357413
shared_generation_parameters["person_generation"] = person_generation
358414

359415
response = self._endpoint.predict(
360-
instances=[instance],
416+
instances=[to_value(instance)],
361417
parameters=parameters,
362418
)
363419

@@ -666,7 +722,7 @@ def upscale_image(
666722
] = output_compression_quality
667723

668724
response = self._endpoint.predict(
669-
instances=[instance],
725+
instances=[to_value(instance)],
670726
parameters=parameters,
671727
)
672728

0 commit comments

Comments
 (0)