Skip to content

Commit 0988543

Browse files
MarkDaoust, shilpakancharla, markmcd
authored
Update to include the generative service. (#105)
* Start code for generative service. This is working with v1beta. Change-Id: I96d2d4f773db8e4be621b40697016d1acbe24903 * Embedding content functionality for v1beta. Change-Id: I1bf2379b33a607b0bb2cf77066073166a6fe9f95 * Add py3.9 support, Fix roles, UsageMetadata, and more. Change-Id: I5fa91f191eb5eebb0fd160325b102bfb48ef8e27 * Add async support. Fix some types. Add count_tokens. Change-Id: Ic03d9caa9996843c0ac1438f52c8b10a08fd6563 * Docstrings Change-Id: Ia2adca04a2f65bfd4ea40eec0cc18c4c98703b9f * Add missing async type to export list. Change-Id: I22e0aba1a59997fef6263105413b5626f1cc5a51 * Add async tests, kwargs, format. Change-Id: Ia2c3260efe58f48183458e997ba560cc07f4b442 * docs Change-Id: I45d2f405fb058f25fc99c30c79221f6461ebe945 * debug tests * Add GenerationConfig at the top level * test * replace __init__.py * remove -e * drop notebook tests for now * Update version. * format + pytype * Fix pytype. * Fix tests + pytype --------- Co-authored-by: Shilpa Kancharla <[email protected]> Co-authored-by: Mark McDonald <[email protected]>
1 parent 5b0b406 commit 0988543

25 files changed

+3008
-142
lines changed

.github/workflows/test_pr.yaml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ jobs:
2323
- name: Run tests
2424
run: |
2525
python --version
26-
pip install -q -e .[dev]
27-
python -m unittest discover --pattern '*test*.py'
26+
pip install .[dev]
27+
python -m unittest
2828
test3_10:
2929
name: Test Py3.10
3030
runs-on: ubuntu-latest
@@ -36,8 +36,8 @@ jobs:
3636
- name: Run tests
3737
run: |
3838
python --version
39-
pip install -q -e .[dev]
40-
python -m unittest discover --pattern '*test*.py'
39+
pip install -q .[dev]
40+
python -m unittest
4141
test3_9:
4242
name: Test Py3.9
4343
runs-on: ubuntu-latest
@@ -49,8 +49,8 @@ jobs:
4949
- name: Run tests
5050
run: |
5151
python --version
52-
pip install -q -e .[dev]
53-
python -m unittest discover --pattern '*test*.py'
52+
pip install .[dev]
53+
python -m unittest
5454
pytype3_10:
5555
name: pytype 3.10
5656
runs-on: ubuntu-latest
@@ -62,7 +62,7 @@ jobs:
6262
- name: Run pytype
6363
run: |
6464
python --version
65-
pip install -q -e .[dev]
65+
pip install .[dev]
6666
pip install -q gspread ipython
6767
pytype
6868
format:
@@ -76,7 +76,7 @@ jobs:
7676
- name: Check format
7777
run: |
7878
python --version
79-
pip install -q -e .
79+
pip install -q .
8080
pip install -q black
8181
black . --check
8282

google/generativeai/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,10 @@
4242
Use the `palm.chat` function to have a discussion with a model:
4343
4444
```
45-
response = palm.chat(messages=["Hello."])
46-
print(response.last) # 'Hello! What can I help you with?'
47-
response.reply("Can you tell me a joke?")
45+
chat = palm.chat(messages=["Hello."])
46+
print(chat.last) # 'Hello! What can I help you with?'
47+
chat = chat.reply("Can you tell me a joke?")
48+
print(chat.last) # 'Why did the chicken cross the road?'
4849
```
4950
5051
## Models
@@ -68,13 +69,20 @@
6869
"""
6970
from __future__ import annotations
7071

71-
from google.generativeai import types
7272
from google.generativeai import version
7373

74+
from google.generativeai import types
75+
from google.generativeai.types import GenerationConfig
76+
77+
7478
from google.generativeai.discuss import chat
7579
from google.generativeai.discuss import chat_async
7680
from google.generativeai.discuss import count_message_tokens
7781

82+
from google.generativeai.embedding import embed_content
83+
84+
from google.generativeai.generative_models import GenerativeModel
85+
7886
from google.generativeai.text import generate_text
7987
from google.generativeai.text import generate_embeddings
8088
from google.generativeai.text import count_text_tokens

google/generativeai/client.py

Lines changed: 54 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,3 @@
1-
# -*- coding: utf-8 -*-
2-
# Copyright 2023 Google LLC
3-
#
4-
# Licensed under the Apache License, Version 2.0 (the "License");
5-
# you may not use this file except in compliance with the License.
6-
# You may obtain a copy of the License at
7-
#
8-
# http://www.apache.org/licenses/LICENSE-2.0
9-
#
10-
# Unless required by applicable law or agreed to in writing, software
11-
# distributed under the License is distributed on an "AS IS" BASIS,
12-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
# See the License for the specific language governing permissions and
14-
# limitations under the License.
151
from __future__ import annotations
162

173
import os
@@ -27,7 +13,12 @@
2713
from google.api_core import gapic_v1
2814
from google.api_core import operations_v1
2915

30-
from google.generativeai import version
16+
try:
17+
from google.generativeai import version
18+
19+
__version__ = version.__version__
20+
except ImportError:
21+
__version__ = "0.0.0"
3122

3223
USER_AGENT = "genai-py"
3324

@@ -36,11 +27,10 @@
3627
class _ClientManager:
3728
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
3829
default_metadata: Sequence[tuple[str, str]] = ()
30+
3931
discuss_client: glm.DiscussServiceClient | None = None
4032
discuss_async_client: glm.DiscussServiceAsyncClient | None = None
41-
model_client: glm.ModelServiceClient | None = None
42-
text_client: glm.TextServiceClient | None = None
43-
operations_client = None
33+
clients: dict[str, Any] = dataclasses.field(default_factory=dict)
4434

4535
def configure(
4636
self,
@@ -54,7 +44,7 @@ def configure(
5444
# We could accept a dict since all the `Transport` classes take the same args,
5545
# but that seems rare. Users that need it can just switch to the low level API.
5646
transport: str | None = None,
57-
client_options: client_options_lib.ClientOptions | dict | None = None,
47+
client_options: client_options_lib.ClientOptions | dict[str, Any] | None = None,
5848
client_info: gapic_v1.client_info.ClientInfo | None = None,
5949
default_metadata: Sequence[tuple[str, str]] = (),
6050
) -> None:
@@ -93,7 +83,7 @@ def configure(
9383

9484
client_options.api_key = api_key
9585

96-
user_agent = f"{USER_AGENT}/{version.__version__}"
86+
user_agent = f"{USER_AGENT}/{__version__}"
9787
if client_info:
9888
# Be respectful of any existing agent setting.
9989
if client_info.user_agent:
@@ -114,12 +104,16 @@ def configure(
114104

115105
self.client_config = client_config
116106
self.default_metadata = default_metadata
117-
self.discuss_client = None
118-
self.text_client = None
119-
self.model_client = None
120-
self.operations_client = None
121107

122-
def make_client(self, cls):
108+
self.clients = {}
109+
110+
def make_client(self, name):
111+
if name.endswith("_async"):
112+
name = name.split("_")[0]
113+
cls = getattr(glm, name.title() + "ServiceAsyncClient")
114+
else:
115+
cls = getattr(glm, name.title() + "ServiceClient")
116+
123117
# Attempt to configure using defaults.
124118
if not self.client_config:
125119
configure()
@@ -157,35 +151,25 @@ def call(*args, metadata=(), **kwargs):
157151

158152
return client
159153

160-
def get_default_discuss_client(self) -> glm.DiscussServiceClient:
161-
if self.discuss_client is None:
162-
self.discuss_client = self.make_client(glm.DiscussServiceClient)
163-
return self.discuss_client
164-
165-
def get_default_text_client(self) -> glm.TextServiceClient:
166-
if self.text_client is None:
167-
self.text_client = self.make_client(glm.TextServiceClient)
168-
return self.text_client
169-
170-
def get_default_discuss_async_client(self) -> glm.DiscussServiceAsyncClient:
171-
if self.discuss_async_client is None:
172-
self.discuss_async_client = self.make_client(glm.DiscussServiceAsyncClient)
173-
return self.discuss_async_client
154+
def get_default_client(self, name):
155+
name = name.lower()
156+
if name == "operations":
157+
return self.get_default_operations_client()
174158

175-
def get_default_model_client(self) -> glm.ModelServiceClient:
176-
if self.model_client is None:
177-
self.model_client = self.make_client(glm.ModelServiceClient)
178-
return self.model_client
159+
client = self.clients.get(name)
160+
if client is None:
161+
client = self.make_client(name)
162+
self.clients[name] = client
163+
return client
179164

180165
def get_default_operations_client(self) -> operations_v1.OperationsClient:
181-
if self.operations_client is None:
182-
self.model_client = get_default_model_client()
183-
self.operations_client = self.model_client._transport.operations_client
184-
185-
return self.operations_client
186-
166+
client = self.clients.get("operations", None)
167+
if client is None:
168+
model_client = self.get_default_client("Model")
169+
client = model_client._transport.operations_client
170+
self.clients["operations"] = client
187171

188-
_client_manager = _ClientManager()
172+
return client
189173

190174

191175
def configure(
@@ -230,21 +214,33 @@ def configure(
230214
)
231215

232216

217+
_client_manager = _ClientManager()
218+
_client_manager.configure()
219+
220+
233221
def get_default_discuss_client() -> glm.DiscussServiceClient:
234-
return _client_manager.get_default_discuss_client()
222+
return _client_manager.get_default_client("discuss")
235223

236224

237-
def get_default_text_client() -> glm.TextServiceClient:
238-
return _client_manager.get_default_text_client()
225+
def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
226+
return _client_manager.get_default_client("discuss_async")
239227

240228

241-
def get_default_operations_client() -> operations_v1.OperationsClient:
242-
return _client_manager.get_default_operations_client()
229+
def get_default_generative_client() -> glm.GenerativeServiceClient:
230+
return _client_manager.get_default_client("generative")
243231

244232

245-
def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
246-
return _client_manager.get_default_discuss_async_client()
233+
def get_default_generative_async_client() -> glm.GenerativeServiceAsyncClient:
234+
return _client_manager.get_default_client("generative_async")
235+
236+
237+
def get_default_text_client() -> glm.TextServiceClient:
238+
return _client_manager.get_default_client("text")
239+
240+
241+
def get_default_operations_client() -> operations_v1.OperationsClient:
242+
return _client_manager.get_default_client("operations")
247243

248244

249245
def get_default_model_client() -> glm.ModelServiceAsyncClient:
250-
return _client_manager.get_default_model_client()
246+
return _client_manager.get_default_client("model")

google/generativeai/discuss.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -301,16 +301,6 @@ def _make_generate_message_request(
301301
)
302302

303303

304-
def set_doc(doc):
305-
"""A decorator to set the docstring of a function."""
306-
307-
def inner(f):
308-
f.__doc__ = doc
309-
return f
310-
311-
return inner
312-
313-
314304
DEFAULT_DISCUSS_MODEL = "models/chat-bison-001"
315305

316306

@@ -411,7 +401,7 @@ def chat(
411401
return _generate_response(client=client, request=request)
412402

413403

414-
@set_doc(chat.__doc__)
404+
@string_utils.set_doc(chat.__doc__)
415405
async def chat_async(
416406
*,
417407
model: model_types.AnyModelNameOptions | None = "models/chat-bison-001",
@@ -447,7 +437,7 @@ async def chat_async(
447437

448438

449439
@string_utils.prettyprint
450-
@set_doc(discuss_types.ChatResponse.__doc__)
440+
@string_utils.set_doc(discuss_types.ChatResponse.__doc__)
451441
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
452442
class ChatResponse(discuss_types.ChatResponse):
453443
_client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False)
@@ -457,7 +447,7 @@ def __init__(self, **kwargs):
457447
setattr(self, key, value)
458448

459449
@property
460-
@set_doc(discuss_types.ChatResponse.last.__doc__)
450+
@string_utils.set_doc(discuss_types.ChatResponse.last.__doc__)
461451
def last(self) -> str | None:
462452
if self.messages[-1]:
463453
return self.messages[-1]["content"]
@@ -470,7 +460,7 @@ def last(self, message: discuss_types.MessageOptions):
470460
message = type(message).to_dict(message)
471461
self.messages[-1] = message
472462

473-
@set_doc(discuss_types.ChatResponse.reply.__doc__)
463+
@string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
474464
def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResponse:
475465
if isinstance(self._client, glm.DiscussServiceAsyncClient):
476466
raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
@@ -489,7 +479,7 @@ def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResp
489479
request = _make_generate_message_request(**request)
490480
return _generate_response(request=request, client=self._client)
491481

492-
@set_doc(discuss_types.ChatResponse.reply.__doc__)
482+
@string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
493483
async def reply_async(
494484
self, message: discuss_types.MessageOptions
495485
) -> discuss_types.ChatResponse:

0 commit comments

Comments
 (0)