Skip to content

Commit 300ac82

Browse files
Giom-Vmarkmcd
andauthored
Adding a python script for Lyria RealTime (#776)
* Reducing the size of the TTS notebook * Further reducing the size * Adding back the open in colab button * Title * Adding a Lyria python script * lint * Colab button * nbfmt --------- Co-authored-by: Mark McDonald <[email protected]>
1 parent 5edcbee commit 300ac82

File tree

2 files changed

+264
-55
lines changed

2 files changed

+264
-55
lines changed
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2025 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
## Setup
18+
19+
To install the dependencies for this script, run:
20+
21+
```
22+
pip install google-genai pyaudio websockets
23+
```
24+
25+
Before running this script, ensure the `GOOGLE_API_KEY` environment
26+
variable is set to the API key you obtained from Google AI Studio.
27+
28+
## Run
29+
30+
To run the script:
31+
32+
```
33+
python LyriaRealTime_EAP.py
34+
```
35+
36+
The script takes a prompt from the command line and streams the audio back over
37+
websockets.
38+
"""
39+
import asyncio
40+
import pyaudio
41+
import os
42+
from google import genai
43+
from google.genai import types
44+
45+
# Longer buffer reduces chance of audio drop, but also delays audio and user commands.
BUFFER_SECONDS = 1
CHUNK = 4200  # Passed to PyAudio as frames_per_buffer for the output stream.
FORMAT = pyaudio.paInt16  # Sample format handed to PyAudio.
CHANNELS = 2  # Number of output channels.
MODEL = 'models/lyria-realtime-exp'
OUTPUT_RATE = 48000  # Playback sample rate passed to PyAudio.

api_key = os.environ.get("GOOGLE_API_KEY")

# Fall back to an interactive prompt when the variable is unset *or* empty;
# an empty key would otherwise be passed straight to the client.
if not api_key:
    print("Please enter your API key")
    api_key = input("API Key: ").strip()

client = genai.Client(
    api_key=api_key,
    # v1alpha since Lyria RealTime is only experimental
    http_options={'api_version': 'v1alpha'},
)
63+
64+
def _parse_weighted_prompts(prompt_str):
    """Parse ``'text1:w1, text2:w2, ...'`` into weighted prompts.

    Each comma-separated segment is split on its *last* colon so that the
    prompt text itself may contain colons. Malformed segments are reported
    on stdout and skipped.

    Args:
        prompt_str: Raw user input containing at least one ``:``.

    Returns:
        A ``(prompts, had_errors)`` tuple: the list of successfully parsed
        ``types.WeightedPrompt`` objects, and whether any segment was skipped
        because it was malformed.
    """
    prompts = []
    had_errors = False

    for raw_segment in prompt_str.split(','):
        segment = raw_segment.strip()
        if not segment:  # Skip empty segments (e.g., from "text1:1, , text2:2")
            continue

        text, separator, weight_str = segment.rpartition(':')
        text = text.strip()
        weight_str = weight_str.strip()

        if not separator:
            print(f"Error: Segment '{raw_segment}' is not in 'text:weight' format. Skipping this segment.")
            had_errors = True
            continue

        if not text:  # Prompt text should not be empty
            print(f"Error: Empty prompt text in segment '{raw_segment}'. Skipping this segment.")
            had_errors = True
            continue

        try:
            weight = float(weight_str)  # Weights are floats
        except ValueError:
            print(f"Error: Invalid weight '{weight_str}' in segment '{raw_segment}'. Must be a number. Skipping this segment.")
            had_errors = True
            continue

        prompts.append(types.WeightedPrompt(text=text, weight=weight))

    return prompts, had_errors


def _find_scale(name):
    """Return the ``types.Scale`` member named *name* (case-insensitive), or None."""
    wanted = name.strip().lower()
    for member in types.Scale:  # types.Scale is an enum
        if member.name.lower() == wanted:
            return member
    return None


async def main():
    """Stream Lyria RealTime audio to the speakers while reading user commands.

    Opens a live music session, starts playback with an initial "Piano"
    prompt at 120 BPM, then runs two concurrent tasks: ``receive`` plays the
    incoming audio chunks through PyAudio and ``send`` reads commands from
    stdin (``q``, ``play``, ``pause``, ``bpm=``, ``scale=``, ``top_k=``,
    weighted prompts or a single text prompt) and forwards them to the
    session. Returns once the user quits or the audio stream ends; PyAudio
    resources are released in all cases.
    """
    p = pyaudio.PyAudio()
    try:
        config = types.LiveMusicGenerationConfig()
        async with client.aio.live.music.connect(model=MODEL) as session:

            async def receive():
                """Play every audio chunk received from the session."""
                chunks_count = 0
                output_stream = p.open(
                    format=FORMAT,
                    channels=CHANNELS,
                    rate=OUTPUT_RATE,
                    output=True,
                    frames_per_buffer=CHUNK,
                )
                try:
                    async for message in session.receive():
                        chunks_count += 1
                        if chunks_count == 1:
                            # Introduce a delay before starting playback to have a buffer for network jitter.
                            await asyncio.sleep(BUFFER_SECONDS)
                        if message.server_content:
                            audio_data = message.server_content.audio_chunks[0].data
                            # NOTE: this write is a synchronous call made from the event loop.
                            output_stream.write(audio_data)
                        elif message.filtered_prompt:
                            print("Prompt was filtered out: ", message.filtered_prompt)
                        else:
                            print("Unknown error occured with message: ", message)
                        # Yield to the event loop so the send task gets a chance to run.
                        await asyncio.sleep(10**-12)
                finally:
                    # Release the audio device even on error or cancellation.
                    output_stream.stop_stream()
                    output_stream.close()

            async def apply_config():
                """Push the current config to the session and reset its context."""
                await session.set_music_generation_config(config=config)
                await session.reset_context()

            async def send():
                """Read commands from stdin and forward them until the user quits."""
                await asyncio.sleep(5)  # Allow initial prompt to play a bit

                while True:
                    print("Set new prompt ((bpm=<number|'AUTO'>, scale=<enum|'AUTO'>, top_k=<number|'AUTO'>, 'play', 'pause', 'prompt1:w1,prompt2:w2,...', or single text prompt)")
                    try:
                        prompt_str = await asyncio.to_thread(input, " > ")
                    except EOFError:
                        # stdin was closed: treat it like an explicit quit.
                        prompt_str = 'q'
                    prompt_str = prompt_str.strip()

                    if not prompt_str:  # Skip empty input
                        continue

                    command = prompt_str.lower()

                    if command == 'q':
                        print("Sending STOP command.")
                        await session.stop()
                        return

                    if command == 'play':
                        print("Sending PLAY command.")
                        await session.play()
                        continue

                    if command == 'pause':
                        print("Sending PAUSE command.")
                        await session.pause()
                        continue

                    if prompt_str.startswith('bpm='):
                        value = prompt_str.removeprefix('bpm=').strip()
                        if value.upper() == 'AUTO':
                            # Assign None instead of `del`: deleting an attribute that
                            # was already deleted raises AttributeError on a repeated
                            # AUTO command. Assumes the SDK treats None as "unset".
                            config.bpm = None
                            print("Setting BPM to AUTO, which requires resetting context.")
                        else:
                            try:
                                bpm_value = int(value)
                            except ValueError:
                                print(f"Error: Invalid BPM '{value}'. Must be an integer or 'AUTO'.")
                                continue
                            print(f"Setting BPM to {bpm_value}, which requires resetting context.")
                            config.bpm = bpm_value
                        await apply_config()
                        continue

                    if prompt_str.startswith('scale='):
                        value = prompt_str.removeprefix('scale=').strip()
                        if value.upper() == 'AUTO':
                            config.scale = None  # Same rationale as for BPM above.
                            print("Setting Scale to AUTO, which requires resetting context.")
                        else:
                            # Match on the value after 'scale=', not the whole input.
                            scale = _find_scale(value)
                            if scale is None:
                                print("Error: Matching enum not found.")
                                continue  # Nothing changed, so keep the current context.
                            print(f"Setting scale to {scale.name}, which requires resetting context.")
                            config.scale = scale
                        await apply_config()
                        continue

                    if prompt_str.startswith('top_k='):
                        value = prompt_str.removeprefix('top_k=').strip()
                        if value.upper() == 'AUTO':
                            config.top_k = None  # Same rationale as for BPM above.
                            print("Setting TopK to AUTO.")
                        else:
                            try:
                                top_k_value = int(value)
                            except ValueError:
                                print(f"Error: Invalid TopK '{value}'. Must be an integer or 'AUTO'.")
                                continue
                            print(f"Setting TopK to {top_k_value}.")
                            config.top_k = top_k_value
                        await apply_config()
                        continue

                    # Check for multiple weighted prompts "prompt1:number1, prompt2:number2, ..."
                    if ":" in prompt_str:
                        parsed_prompts, had_errors = _parse_weighted_prompts(prompt_str)

                        if not parsed_prompts:
                            print("Error: Input contained ':' suggesting multi-prompt format, but no valid 'text:weight' segments were successfully parsed. No action taken.")
                            continue

                        prompt_repr = ', '.join(f"'{wp.text}':{wp.weight}" for wp in parsed_prompts)
                        if had_errors:
                            print(f"Partially sending {len(parsed_prompts)} valid weighted prompt(s) due to errors in other segments: {prompt_repr}")
                        else:
                            print(f"Sending multiple weighted prompts: {prompt_repr}")
                        await session.set_weighted_prompts(prompts=parsed_prompts)
                        continue

                    # If none of the above, treat as a regular single text prompt
                    print(f'Sending single text prompt: "{prompt_str}"')
                    await session.set_weighted_prompts(
                        prompts=[types.WeightedPrompt(text=prompt_str, weight=1.0)]
                    )

            print("Starting with some piano")
            await session.set_weighted_prompts(
                prompts=[types.WeightedPrompt(text="Piano", weight=1.0)]
            )

            # Set initial BPM and Scale
            config.bpm = 120
            config.scale = types.Scale.A_FLAT_MAJOR_F_MINOR  # Example initial scale
            print(f"Setting initial BPM to {config.bpm} and scale to {config.scale.name}")
            await session.set_music_generation_config(config=config)

            print("Let's get the party started!")
            await session.play()

            send_task = asyncio.create_task(send())
            receive_task = asyncio.create_task(receive())

            # Stop as soon as either task ends: `send` returns when the user
            # quits, `receive` ends when the stream closes or fails. Waiting
            # for both would hang forever after 'q'.
            done, pending = await asyncio.wait(
                {send_task, receive_task},
                return_when=asyncio.FIRST_COMPLETED,
            )
            if send_task in pending:
                # The input() call runs in a worker thread that cannot be
                # interrupted; it only finishes once the user presses Enter.
                print("Audio stream ended. Press Enter to exit.")
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)
            for task in done:
                task.result()  # Re-raise any exception from the finished task.
    finally:
        # Clean up PyAudio
        p.terminate()
230+
231+
# Only start the interactive session when executed as a script, so the
# module can be imported without immediately opening a live connection.
if __name__ == "__main__":
    asyncio.run(main())

quickstarts/Get_started_TTS.ipynb

Lines changed: 33 additions & 55 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)