
Commit 59642c2 (merge)

Change-Id: I361c62ffbe3b7b7471f54382cb2b02cd6e5e5e05
Parents: 805a0f4 + 8f7f5cb

13 files changed: +188 −98 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ The Google AI Python SDK is the easiest way for Python developers to build with
 ## Get started with the Gemini API
 1. Go to [Google AI Studio](https://aistudio.google.com/).
 2. Login with your Google account.
-3. [Create](https://aistudio.google.com/app/apikey) an API key. Note that in Europe the free tier is not available.
+3. [Create](https://aistudio.google.com/app/apikey) an API key.
 4. Try a Python SDK [quickstart](https://github.com/google-gemini/gemini-api-cookbook/blob/main/quickstarts/Prompting.ipynb) in the [Gemini API Cookbook](https://github.com/google-gemini/gemini-api-cookbook/).
 5. For detailed instructions, try the
    [Python SDK tutorial](https://ai.google.dev/tutorials/python_quickstart) on [ai.google.dev](https://ai.google.dev).

google/generativeai/client.py

Lines changed: 29 additions & 12 deletions
@@ -2,9 +2,9 @@
 
 import os
 import contextlib
+import inspect
 import dataclasses
 import pathlib
-import types
 from typing import Any, cast
 from collections.abc import Sequence
 import httplib2
@@ -30,6 +30,21 @@
 __version__ = "0.0.0"
 
 USER_AGENT = "genai-py"
+
+#### Caution! ####
+# - It would make sense for the discovery URL to respect the client_options.endpoint setting.
+#   - That would make testing Files on the staging server possible.
+# - We tried fixing this once, but broke colab in the process because their endpoint didn't forward the discovery
+#   requests. https://github.com/google-gemini/generative-ai-python/pull/333
+# - Kaggle would have a similar problem (b/362278209).
+#   - I think their proxy would forward the discovery traffic.
+#   - But they don't need to intercept the files-service at all, and uploads of large files could overload them.
+#   - Do the scotty uploads go to the same domain?
+#   - If you do route the discovery call to kaggle, be sure to attach the default_metadata (they need it).
+# - One solution to all this would be if configure could take overrides per service.
+#   - set client_options.endpoint, but use a different endpoint for file service? It's not clear how best to do that
+#     through the file service.
+##################
 GENAI_API_DISCOVERY_URL = "https://generativelanguage.googleapis.com/$discovery/rest"
 
 
@@ -50,7 +65,7 @@ def __init__(self, *args, **kwargs):
         self._discovery_api = None
         super().__init__(*args, **kwargs)
 
-    def _setup_discovery_api(self):
+    def _setup_discovery_api(self, metadata: dict | Sequence[tuple[str, str]] = ()):
         api_key = self._client_options.api_key
         if api_key is None:
             raise ValueError(
@@ -61,6 +76,7 @@ def _setup_discovery_api(self):
             http=httplib2.Http(),
             postproc=lambda resp, content: (resp, content),
             uri=f"{GENAI_API_DISCOVERY_URL}?version=v1beta&key={api_key}",
+            headers=dict(metadata),
         )
         response, content = request.execute()
         request.http.close()
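
The `headers=dict(metadata)` argument attaches client metadata to the discovery fetch itself, not just to subsequent API calls. For context, a minimal standalone sketch of fetching a discovery document with extra headers and building a service from it (the header value and `YOUR_API_KEY` placeholder are illustrative):

```python
import googleapiclient.discovery
import httplib2

# Illustrative: fetch the discovery document with extra request headers.
http = httplib2.Http()
resp, content = http.request(
    "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta&key=YOUR_API_KEY",
    headers={"x-goog-api-client": "genai-py"},
)
# Build a client from the fetched document, as _setup_discovery_api does.
discovery_api = googleapiclient.discovery.build_from_document(content)
```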
@@ -78,9 +94,10 @@ def create_file(
         name: str | None = None,
         display_name: str | None = None,
         resumable: bool = True,
+        metadata: Sequence[tuple[str, str]] = (),
     ) -> protos.File:
         if self._discovery_api is None:
-            self._setup_discovery_api()
+            self._setup_discovery_api(metadata)
 
         file = {}
         if name is not None:
@@ -92,6 +109,8 @@ def create_file(
             filename=path, mimetype=mime_type, resumable=resumable
         )
         request = self._discovery_api.media().upload(body={"file": file}, media_body=media)
+        for key, value in metadata:
+            request.headers[key] = value
         result = request.execute()
 
         return self.get_file({"name": result["file"]["name"]})
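
At the public surface nothing changes; a typical upload still looks like the sketch below (API key and file name are placeholders), with the SDK's default metadata now attached to both the discovery call and the upload request behind the scenes:

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # placeholder

# create_file is reached through genai.upload_file; the metadata pairs are
# injected onto the underlying googleapiclient request's headers.
sample = genai.upload_file("image.png", display_name="sample image")
print(sample.name)  # e.g. "files/abc-123"
```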
@@ -226,16 +245,14 @@ def make_client(self, name):
         def keep(name, f):
             if name.startswith("_"):
                 return False
-            elif name == "create_file":
-                return False
-            elif not isinstance(f, types.FunctionType):
-                return False
-            elif isinstance(f, classmethod):
+
+            if not callable(f):
                 return False
-            elif isinstance(f, staticmethod):
+
+            if "metadata" not in inspect.signature(f).parameters.keys():
                 return False
-            else:
-                return True
+
+            return True
 
         def add_default_metadata_wrapper(f):
             def call(*args, metadata=(), **kwargs):
@@ -244,7 +261,7 @@ def call(*args, metadata=(), **kwargs):
 
             return call
 
-        for name, value in cls.__dict__.items():
+        for name, value in inspect.getmembers(cls):
             if not keep(name, value):
                 continue
             f = getattr(client, name)
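
Switching from `cls.__dict__.items()` to `inspect.getmembers(cls)` also picks up methods defined on base classes, and the signature check replaces the old hard-coded exclusion list. A self-contained sketch of the same filter-and-wrap pattern (the `DemoClient` class and `default_metadata` value are hypothetical, not SDK names):

```python
import inspect

class DemoClient:  # hypothetical stand-in for a generated API client
    def get_model(self, name, *, metadata=()):
        return f"get_model({name}, metadata={list(metadata)})"

    def helper(self):  # no `metadata` parameter, so it is left unwrapped
        return "helper"

default_metadata = [("x-goog-api-client", "genai-py/0.8.2")]

def add_default_metadata_wrapper(f):
    def call(*args, metadata=(), **kwargs):
        # Merge per-call metadata with the configured defaults.
        return f(*args, metadata=list(metadata) + default_metadata, **kwargs)
    return call

client = DemoClient()
for name, value in inspect.getmembers(DemoClient):
    if name.startswith("_") or not callable(value):
        continue
    if "metadata" not in inspect.signature(value).parameters:
        continue
    setattr(client, name, add_default_metadata_wrapper(getattr(client, name)))

print(client.get_model("models/demo"))
# get_model(models/demo, metadata=[('x-goog-api-client', 'genai-py/0.8.2')])
```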

google/generativeai/types/generation_types.py

Lines changed: 84 additions & 10 deletions
@@ -144,17 +144,27 @@ class GenerationConfig:
             Note: The default value varies by model, see the
             `Model.top_k` attribute of the `Model` returned the
             `genai.get_model` function.
-
+        seed:
+            Optional. Seed used in decoding. If not set, the request uses a randomly generated seed.
         response_mime_type:
             Optional. Output response mimetype of the generated candidate text.
 
             Supported mimetype:
                 `text/plain`: (default) Text output.
+                `text/x-enum`: for use with a string-enum in `response_schema`
                 `application/json`: JSON response in the candidates.
 
         response_schema:
             Optional. Specifies the format of the JSON requested if response_mime_type is
             `application/json`.
+        presence_penalty:
+            Optional.
+        frequency_penalty:
+            Optional.
+        response_logprobs:
+            Optional. If true, export the `logprobs` results in response.
+        logprobs:
+            Optional. Number of candidates of log probabilities to return at each step of decoding.
     """
 
     candidate_count: int | None = None
@@ -163,8 +173,13 @@ class GenerationConfig:
     temperature: float | None = None
     top_p: float | None = None
     top_k: int | None = None
+    seed: int | None = None
     response_mime_type: str | None = None
     response_schema: protos.Schema | Mapping[str, Any] | type | None = None
+    presence_penalty: float | None = None
+    frequency_penalty: float | None = None
+    response_logprobs: bool | None = None
+    logprobs: int | None = None
 
 
 GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
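
Taken together, the new dataclass fields can be exercised like this (a sketch; the values are illustrative, and `genai.GenerationConfig` is the public alias for this class):

```python
import google.generativeai as genai

# Illustrative settings: a fixed seed for reproducible sampling, repetition
# penalties, and the top-3 log-probabilities at each decoding step.
config = genai.GenerationConfig(
    temperature=0.7,
    seed=1234,
    presence_penalty=0.6,
    frequency_penalty=0.3,
    response_logprobs=True,
    logprobs=3,
)
```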
@@ -306,6 +321,7 @@ def _join_code_execution_result(result_1, result_2):
 
 
 def _join_candidates(candidates: Iterable[protos.Candidate]):
+    """Joins stream chunks of a single candidate."""
     candidates = tuple(candidates)
 
     index = candidates[0].index  # These should all be the same.
@@ -321,6 +337,7 @@ def _join_candidates(candidates: Iterable[protos.Candidate]):
 
 
 def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]):
+    """Joins stream chunks where each chunk is a list of candidate chunks."""
     # Assuming that is a candidate ends, it is no longer returned in the list of
     # candidates and that's why candidates have an index
     candidates = collections.defaultdict(list)
@@ -344,10 +361,15 @@ def _join_prompt_feedbacks(
 
 def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]):
     chunks = tuple(chunks)
+    if "usage_metadata" in chunks[-1]:
+        usage_metadata = chunks[-1].usage_metadata
+    else:
+        usage_metadata = None
+
     return protos.GenerateContentResponse(
         candidates=_join_candidate_lists(c.candidates for c in chunks),
         prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
-        usage_metadata=chunks[-1].usage_metadata,
+        usage_metadata=usage_metadata,
     )
 
 
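The guard relies on proto-plus field presence: `"usage_metadata" in chunk` is only true once the field has actually been set, so streams whose final chunk carries no usage metadata no longer produce a spurious value. A small sketch of that behavior (assuming proto-plus presence semantics):

```python
from google.generativeai import protos

chunk = protos.GenerateContentResponse()
print("usage_metadata" in chunk)  # False: the field was never set

chunk.usage_metadata.total_token_count = 42
print("usage_metadata" in chunk)  # True: the field is now populated
```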
@@ -412,14 +434,22 @@ def parts(self):
         """
         candidates = self.candidates
         if not candidates:
-            raise ValueError(
+            msg = (
                 "Invalid operation: The `response.parts` quick accessor requires a single candidate, "
-                "but none were returned. Please check the `response.prompt_feedback` to determine if the prompt was blocked."
+                "but `response.candidates` is empty."
             )
+            if self.prompt_feedback:
+                raise ValueError(
+                    msg + "\nThis appears to be caused by a blocked prompt, "
+                    f"see `response.prompt_feedback`: {self.prompt_feedback}"
+                )
+            else:
+                raise ValueError(msg)
+
         if len(candidates) > 1:
             raise ValueError(
-                "Invalid operation: The `response.parts` quick accessor requires a single candidate. "
-                "For multiple candidates, please use `result.candidates[index].text`."
+                "Invalid operation: The `response.parts` quick accessor retrieves the parts for a single candidate. "
+                "This response contains multiple candidates, please use `result.candidates[index].text`."
             )
         parts = candidates[0].content.parts
         return parts
@@ -433,10 +463,53 @@ def text(self):
         """
         parts = self.parts
         if not parts:
-            raise ValueError(
-                "Invalid operation: The `response.text` quick accessor requires the response to contain a valid `Part`, "
-                "but none were returned. Please check the `candidate.safety_ratings` to determine if the response was blocked."
+            candidate = self.candidates[0]
+
+            fr = candidate.finish_reason
+            FinishReason = protos.Candidate.FinishReason
+
+            msg = (
+                "Invalid operation: The `response.text` quick accessor requires the response to contain a valid "
+                "`Part`, but none were returned. The candidate's "
+                f"[finish_reason](https://ai.google.dev/api/generate-content#finishreason) is {fr}."
             )
+            if candidate.finish_message:
+                msg += f' The `finish_message` is "{candidate.finish_message}".'
+
+            if fr is FinishReason.FINISH_REASON_UNSPECIFIED:
+                raise ValueError(msg)
+            elif fr is FinishReason.STOP:
+                raise ValueError(msg)
+            elif fr is FinishReason.MAX_TOKENS:
+                raise ValueError(msg)
+            elif fr is FinishReason.SAFETY:
+                raise ValueError(
+                    msg + f" The candidate's safety_ratings are: {candidate.safety_ratings}.",
+                    candidate.safety_ratings,
+                )
+            elif fr is FinishReason.RECITATION:
+                raise ValueError(
+                    msg + " Meaning that the model was reciting from copyrighted material."
+                )
+            elif fr is FinishReason.LANGUAGE:
+                raise ValueError(msg + " Meaning the response was using an unsupported language.")
+            elif fr is FinishReason.OTHER:
+                raise ValueError(msg)
+            elif fr is FinishReason.BLOCKLIST:
+                raise ValueError(msg)
+            elif fr is FinishReason.PROHIBITED_CONTENT:
+                raise ValueError(msg)
+            elif fr is FinishReason.SPII:
+                raise ValueError(msg + " SPII - Sensitive Personally Identifiable Information.")
+            elif fr is FinishReason.MALFORMED_FUNCTION_CALL:
+                raise ValueError(
+                    msg + " Meaning that model generated a `FunctionCall` that was invalid. "
+                    "Setting the "
+                    "[Function calling mode](https://ai.google.dev/gemini-api/docs/function-calling#function_calling_mode) "
+                    "to `ANY` can fix this because it enables constrained decoding."
+                )
+            else:
+                raise ValueError(msg)
 
         texts = []
         for part in parts:
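
For callers, the practical difference is a much more specific `ValueError`; defensive code can stay simple (a sketch, model name illustrative):

```python
import google.generativeai as genai

model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content("Tell me a story")

try:
    print(response.text)
except ValueError as e:
    # The message now names the candidate's finish_reason (SAFETY,
    # RECITATION, MALFORMED_FUNCTION_CALL, ...) instead of only pointing
    # at safety_ratings.
    print(f"No text returned: {e}")
```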
@@ -490,7 +563,8 @@ def __str__(self) -> str:
         _result = _result.replace("\n", "\n ")
 
         if self._error:
-            _error = f",\nerror=<{self._error.__class__.__name__}> {self._error}"
+
+            _error = f",\nerror={repr(self._error)}"
         else:
             _error = ""
 
google/generativeai/types/model_types.py

Lines changed: 4 additions & 1 deletion
@@ -143,7 +143,9 @@ def idecode_time(parent: dict["str", Any], name: str):
 
 def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel:
     if isinstance(tuned_model, protos.TunedModel):
-        tuned_model = type(tuned_model).to_dict(tuned_model)  # pytype: disable=attribute-error
+        tuned_model = type(tuned_model).to_dict(
+            tuned_model, including_default_value_fields=False
+        )  # pytype: disable=attribute-error
     tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None))
 
     base_model = tuned_model.pop("base_model", None)
@@ -195,6 +197,7 @@ class TunedModel:
     create_time: datetime.datetime | None = None
     update_time: datetime.datetime | None = None
     tuning_task: TuningTask | None = None
+    reader_project_numbers: list[int] | None = None
 
     @property
     def permissions(self) -> permission_types.Permissions:
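
Passing `including_default_value_fields=False` keeps unset proto fields (such as the new `reader_project_numbers`) out of the dict instead of materializing them with empty defaults. A hedged illustration of the difference:

```python
from google.generativeai import protos

tm = protos.TunedModel(name="tunedModels/demo")  # illustrative name

full = type(tm).to_dict(tm)  # defaults included
sparse = type(tm).to_dict(tm, including_default_value_fields=False)

print("state" in full)    # True: present with its default value
print("state" in sparse)  # False: only explicitly set fields survive
```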

google/generativeai/version.py

Lines changed: 1 addition & 1 deletion
@@ -14,4 +14,4 @@
 # limitations under the License.
 from __future__ import annotations
 
-__version__ = "0.7.2"
+__version__ = "0.8.2"

samples/controlled_generation.py

Lines changed: 2 additions & 2 deletions
@@ -139,7 +139,7 @@ class Choice(enum.Enum):
                 response_mime_type="text/x.enum", response_schema=Choice
             ),
         )
-        print(result)  # "Keyboard"
+        print(result)  # Keyboard
         # [END x_enum]
 
     def test_x_enum_raw(self):
@@ -157,7 +157,7 @@ def test_x_enum_raw(self):
                 },
             ),
         )
-        print(result)  # "Keyboard"
+        print(result)  # Keyboard
         # [END x_enum_raw]
 
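For context, the `x_enum` sample these comment fixes belong to constrains decoding to a string enum; a condensed sketch modeled on it (model name and prompt are illustrative):

```python
import enum
import google.generativeai as genai

class Choice(enum.Enum):
    PERCUSSION = "Percussion"
    STRING = "String"
    WOODWIND = "Woodwind"
    BRASS = "Brass"
    KEYBOARD = "Keyboard"

model = genai.GenerativeModel("gemini-1.5-flash")
result = model.generate_content(
    "What type of instrument is an organ?",
    generation_config=genai.GenerationConfig(
        # text/x.enum restricts the output to one of the enum's values.
        response_mime_type="text/x.enum",
        response_schema=Choice,
    ),
)
print(result.text)  # Keyboard
```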
setup.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def get_version():
 release_status = "Development Status :: 5 - Production/Stable"
 
 dependencies = [
-    "google-ai-generativelanguage==0.6.6",
+    "google-ai-generativelanguage==0.6.10",
     "google-api-core",
     "google-api-python-client",
     "google-auth>=2.15.0",  # 2.15 adds API key auth support

tests/test_files.py

Lines changed: 7 additions & 5 deletions
@@ -12,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 
 from google.generativeai.types import file_types
 
 import collections
 import datetime
 import os
-from typing import Iterable, Union
+from typing import Iterable, Sequence
 import pathlib
 
 import google
@@ -37,12 +38,13 @@ def __init__(self, test):
 
     def create_file(
         self,
-        path: Union[str, pathlib.Path, os.PathLike],
+        path: str | pathlib.Path | os.PathLike,
         *,
-        mime_type: Union[str, None] = None,
-        name: Union[str, None] = None,
-        display_name: Union[str, None] = None,
+        mime_type: str | None = None,
+        name: str | None = None,
+        display_name: str | None = None,
         resumable: bool = True,
+        metadata: Sequence[tuple[str, str]] = (),
     ) -> protos.File:
         self.observed_requests.append(
             dict(
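
The added `from __future__ import annotations` is what lets the fake client use PEP 604 unions like `str | None` even on Python versions before 3.10: annotations are stored as strings and never evaluated at definition time. A minimal illustration:

```python
from __future__ import annotations

import pathlib

# Without the future import, `str | pathlib.Path` in an annotation raises
# TypeError at definition time on Python 3.9 and earlier.
def create_file(path: str | pathlib.Path, mime_type: str | None = None) -> None:
    ...
```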
