feat(genai): Add Live API samples v2 #13523

Draft: wants to merge 8 commits into main

Changes from 2 commits
73 changes: 73 additions & 0 deletions genai/live/live_ground_ragengine_with_txt.py
@@ -0,0 +1,73 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio

_memory_corpus = "projects/cloud-ai-devrel-softserve/locations/us-central1/ragCorpora/2305843009213693952"
Comment on lines +14 to +16
Contributor

Severity: medium

For better reusability and to avoid exposing project-specific details, it's recommended to construct the _memory_corpus string from environment variables for the project ID and a placeholder for the user-specific RAG corpus ID. This requires importing the os module and aligns with the practices in other samples.

import asyncio
import os

# TODO(developer): Set this to your RAG Corpus ID.
RAG_CORPUS_ID = "your-rag-corpus-id"
_memory_corpus = f"projects/{os.getenv('GOOGLE_CLOUD_PROJECT')}/locations/us-central1/ragCorpora/{RAG_CORPUS_ID}"
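
Reading the project ID from GOOGLE_CLOUD_PROJECT would also match the other samples in this directory, which appear to rely on the CI/CD pipeline setting that variable (see test_live_examples.py below).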



async def generate_content(memory_corpus: str) -> list[str]:
# [START googlegenaisdk_live_ground_ragengine_with_txt]
from google import genai
from google.genai.types import (
Content,
LiveConnectConfig,
Modality,
Part,
Tool,
Retrieval,
VertexRagStore,
VertexRagStoreRagResource,
)

client = genai.Client()
model_id = "gemini-2.0-flash-live-preview-04-09"
rag_store = VertexRagStore(
rag_resources=[
VertexRagStoreRagResource(
rag_corpus=memory_corpus  # Use a memory corpus if you want to store context.
)
],
# Set `store_context` to True to allow the Live API to store context in your memory corpus.
store_context=True,
)
config = LiveConnectConfig(
response_modalities=[Modality.TEXT],
tools=[Tool(retrieval=Retrieval(vertex_rag_store=rag_store))],
)

async with client.aio.live.connect(model=model_id, config=config) as session:
text_input = "What year did Mariusz Pudzianowski win World's Strongest Man?"
print("> ", text_input, "\n")

await session.send_client_content(
turns=Content(role="user", parts=[Part(text=text_input)])
)

response = []

async for message in session.receive():
if message.text:
response.append(message.text)
continue

print("".join(response))
# Example output:
# > What year did Mariusz Pudzianowski win World's Strongest Man?
# Mariusz Pudzianowski won World's Strongest Man in 2002, 2003, 2005, 2007, and 2008.
# [END googlegenaisdk_live_ground_ragengine_with_txt]
return response


if __name__ == "__main__":
asyncio.run(generate_content(_memory_corpus))
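
For reference, a minimal usage sketch that builds the corpus resource name from environment variables, in the spirit of the review suggestion above. The GOOGLE_CLOUD_PROJECT and RAG_CORPUS_ID variable names and the direct module import are assumptions for illustration, not part of this sample:

import asyncio
import os

# Assumes you run from genai/live/ so the sample module is importable.
from live_ground_ragengine_with_txt import generate_content

# Assumed environment variables; replace with your own project and corpus IDs.
project_id = os.environ["GOOGLE_CLOUD_PROJECT"]
corpus_id = os.environ.get("RAG_CORPUS_ID", "your-rag-corpus-id")
corpus = f"projects/{project_id}/locations/us-central1/ragCorpora/{corpus_id}"

asyncio.run(generate_content(corpus))
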
28 changes: 17 additions & 11 deletions genai/live/live_websocket_audiogen_with_txt.py
@@ -20,7 +20,9 @@ def get_bearer_token() -> str:
import google.auth
from google.auth.transport.requests import Request

creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
creds, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
auth_req = Request()
creds.refresh(auth_req)
bearer_token = creds.token
@@ -55,9 +57,7 @@ async def generate_content() -> str:

# Websocket Configuration
WEBSOCKET_HOST = "us-central1-aiplatform.googleapis.com"
WEBSOCKET_SERVICE_URL = (
f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
)
WEBSOCKET_SERVICE_URL = f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"

# Websocket Authentication
headers = {
@@ -66,9 +66,7 @@ async def generate_content() -> str:
}

# Model Configuration
model_path = (
f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
)
model_path = f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
model_generation_config = {
"response_modalities": ["AUDIO"],
"speech_config": {
@@ -77,7 +75,9 @@ async def generate_content() -> str:
},
}

async with connect(WEBSOCKET_SERVICE_URL, additional_headers=headers) as websocket_session:
async with connect(
WEBSOCKET_SERVICE_URL, additional_headers=headers
) as websocket_session:
# 1. Send setup configuration
websocket_config = {
"setup": {
@@ -120,7 +120,9 @@ async def generate_content() -> str:
server_content = response_chunk.get("serverContent")
if not server_content:
# This might indicate an error or an unexpected message format
print(f"Received non-serverContent message or empty content: {response_chunk}")
print(
f"Received non-serverContent message or empty content: {response_chunk}"
)
break

# Collect audio chunks
@@ -129,15 +131,19 @@ async def generate_content() -> str:
for part in model_turn["parts"]:
if part["inlineData"]["mimeType"] == "audio/pcm":
audio_chunk = base64.b64decode(part["inlineData"]["data"])
aggregated_response_parts.append(np.frombuffer(audio_chunk, dtype=np.int16))
aggregated_response_parts.append(
np.frombuffer(audio_chunk, dtype=np.int16)
)

# End of response
if server_content.get("turnComplete"):
break

# Save audio to a file
if aggregated_response_parts:
wavfile.write("output.wav", 24000, np.concatenate(aggregated_response_parts))
wavfile.write(
"output.wav", 24000, np.concatenate(aggregated_response_parts)
)
# Example response:
# Setup Response: {'setupComplete': {}}
# Input: Hello? Gemini are you there?
24 changes: 14 additions & 10 deletions genai/live/live_websocket_audiotranscript_with_txt.py
@@ -20,7 +20,9 @@ def get_bearer_token() -> str:
import google.auth
from google.auth.transport.requests import Request

creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
creds, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
auth_req = Request()
creds.refresh(auth_req)
bearer_token = creds.token
@@ -55,9 +57,7 @@ async def generate_content() -> str:

# Websocket Configuration
WEBSOCKET_HOST = "us-central1-aiplatform.googleapis.com"
WEBSOCKET_SERVICE_URL = (
f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
)
WEBSOCKET_SERVICE_URL = f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"

# Websocket Authentication
headers = {
@@ -66,9 +66,7 @@ async def generate_content() -> str:
}

# Model Configuration
model_path = (
f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
)
model_path = f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
model_generation_config = {
"response_modalities": ["AUDIO"],
"speech_config": {
@@ -77,7 +75,9 @@ async def generate_content() -> str:
},
}

async with connect(WEBSOCKET_SERVICE_URL, additional_headers=headers) as websocket_session:
async with connect(
WEBSOCKET_SERVICE_URL, additional_headers=headers
) as websocket_session:
# 1. Send setup configuration
websocket_config = {
"setup": {
@@ -125,7 +125,9 @@ async def generate_content() -> str:
server_content = response_chunk.get("serverContent")
if not server_content:
# This might indicate an error or an unexpected message format
print(f"Received non-serverContent message or empty content: {response_chunk}")
print(
f"Received non-serverContent message or empty content: {response_chunk}"
)
break

# Transcriptions
@@ -142,7 +144,9 @@ async def generate_content() -> str:
for part in model_turn["parts"]:
if part["inlineData"]["mimeType"] == "audio/pcm":
audio_chunk = base64.b64decode(part["inlineData"]["data"])
aggregated_response_parts.append(np.frombuffer(audio_chunk, dtype=np.int16))
aggregated_response_parts.append(
np.frombuffer(audio_chunk, dtype=np.int16)
)

# End of response
if server_content.get("turnComplete"):
24 changes: 14 additions & 10 deletions genai/live/live_websocket_textgen_with_audio.py
@@ -20,7 +20,9 @@ def get_bearer_token() -> str:
import google.auth
from google.auth.transport.requests import Request

creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
creds, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
auth_req = Request()
creds.refresh(auth_req)
bearer_token = creds.token
@@ -65,9 +67,7 @@ def read_wavefile(filepath: str) -> tuple[str, str]:

# Websocket Configuration
WEBSOCKET_HOST = "us-central1-aiplatform.googleapis.com"
WEBSOCKET_SERVICE_URL = (
f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
)
WEBSOCKET_SERVICE_URL = f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"

# Websocket Authentication
headers = {
@@ -76,12 +76,12 @@ def read_wavefile(filepath: str) -> tuple[str, str]:
}

# Model Configuration
model_path = (
f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
)
model_path = f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
model_generation_config = {"response_modalities": ["TEXT"]}

async with connect(WEBSOCKET_SERVICE_URL, additional_headers=headers) as websocket_session:
async with connect(
WEBSOCKET_SERVICE_URL, additional_headers=headers
) as websocket_session:
# 1. Send setup configuration
websocket_config = {
"setup": {
@@ -105,7 +105,9 @@ def read_wavefile(filepath: str) -> tuple[str, str]:
return "Error: WebSocket setup failed."

# 3. Send audio message
encoded_audio_message, mime_type = read_wavefile("hello_gemini_are_you_there.wav")
encoded_audio_message, mime_type = read_wavefile(
"hello_gemini_are_you_there.wav"
)
# Example audio message: "Hello? Gemini are you there?"

user_message = {
@@ -136,7 +138,9 @@ def read_wavefile(filepath: str) -> tuple[str, str]:
server_content = response_chunk.get("serverContent")
if not server_content:
# This might indicate an error or an unexpected message format
print(f"Received non-serverContent message or empty content: {response_chunk}")
print(
f"Received non-serverContent message or empty content: {response_chunk}"
)
break

# Collect text responses
20 changes: 11 additions & 9 deletions genai/live/live_websocket_textgen_with_txt.py
@@ -20,7 +20,9 @@ def get_bearer_token() -> str:
import google.auth
from google.auth.transport.requests import Request

creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
creds, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
auth_req = Request()
creds.refresh(auth_req)
bearer_token = creds.token
@@ -51,9 +53,7 @@ async def generate_content() -> str:

# Websocket Configuration
WEBSOCKET_HOST = "us-central1-aiplatform.googleapis.com"
WEBSOCKET_SERVICE_URL = (
f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
)
WEBSOCKET_SERVICE_URL = f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"

# Websocket Authentication
headers = {
@@ -62,12 +62,12 @@ async def generate_content() -> str:
}

# Model Configuration
model_path = (
f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
)
model_path = f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
model_generation_config = {"response_modalities": ["TEXT"]}

async with connect(WEBSOCKET_SERVICE_URL, additional_headers=headers) as websocket_session:
async with connect(
WEBSOCKET_SERVICE_URL, additional_headers=headers
) as websocket_session:
# 1. Send setup configuration
websocket_config = {
"setup": {
@@ -110,7 +110,9 @@ async def generate_content() -> str:
server_content = response_chunk.get("serverContent")
if not server_content:
# This might indicate an error or an unexpected message format
print(f"Received non-serverContent message or empty content: {response_chunk}")
print(
f"Received non-serverContent message or empty content: {response_chunk}"
)
break

# Collect text responses
4 changes: 3 additions & 1 deletion genai/live/live_with_txt.py
@@ -35,7 +35,9 @@ async def generate_content() -> list[str]:
) as session:
text_input = "Hello? Gemini, are you there?"
print("> ", text_input, "\n")
await session.send_client_content(turns=Content(role="user", parts=[Part(text=text_input)]))
await session.send_client_content(
turns=Content(role="user", parts=[Part(text=text_input)])
)

response = []

29 changes: 29 additions & 0 deletions genai/live/test_live_examples.py
@@ -25,13 +25,37 @@
import live_websocket_textgen_with_audio
import live_websocket_textgen_with_txt
import live_with_txt
import live_ground_ragengine_with_txt

os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "True"
os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1"
# The project name is included in the CICD pipeline
# os.environ['GOOGLE_CLOUD_PROJECT'] = "add-your-project-name"


@pytest.fixture()
def mock_rag_components(mocker):
mock_client_cls = mocker.patch("google.genai.Client")

class AsyncIterator:
def __aiter__(self):
return self

async def __anext__(self):
if not hasattr(self, "used"):
self.used = True
return mocker.MagicMock(
text="Mariusz Pudzianowski won in 2002, 2003, 2005, 2007, and 2008."
)
raise StopAsyncIteration

mock_session = mocker.AsyncMock()
mock_session.__aenter__.return_value = mock_session
mock_session.receive = lambda: AsyncIterator()

mock_client_cls.return_value.aio.live.connect.return_value = mock_session


@pytest.mark.asyncio
async def test_live_with_text() -> None:
assert await live_with_txt.generate_content()
@@ -55,3 +79,8 @@ async def test_live_websocket_audiogen_with_txt() -> None:
@pytest.mark.asyncio
async def test_live_websocket_audiotranscript_with_txt() -> None:
assert await live_websocket_audiotranscript_with_txt.generate_content()


@pytest.mark.asyncio
async def test_live_ground_ragengine_with_txt(mock_rag_components) -> None:
assert await live_ground_ragengine_with_txt.generate_content("test")
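
One possible way to exercise only the new test locally, assuming pytest and pytest-asyncio are installed (the -k expression and path are illustrative, not part of this PR):

# Hypothetical local runner for the new RAG Engine test; run from genai/live/.
import sys

import pytest

sys.exit(pytest.main(["-k", "live_ground_ragengine", "test_live_examples.py", "-v"]))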