Skip to content

Commit 6e5c0eb

Browse files
wukath authored and copybara-github committed
feat: Add token usage to live events for bidi streaming
Populate the usage_metadata field for live events with the metadata provided by the Gemini live API. Co-authored-by: Kathy Wu <[email protected]> PiperOrigin-RevId: 828124232
1 parent 4f85e86 commit 6e5c0eb

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -581,6 +581,7 @@ async def _postprocess_live(
581581
and not llm_response.turn_complete
582582
and not llm_response.input_transcription
583583
and not llm_response.output_transcription
584+
and not llm_response.usage_metadata
584585
):
585586
return
586587

src/google/adk/models/gemini_llm_connection.py

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -148,6 +148,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
148148
# partial content and emit responses as needed.
149149
async for message in agen:
150150
logger.debug('Got LLM Live message: %s', message)
151+
if message.usage_metadata:
152+
yield LlmResponse(usage_metadata=message.usage_metadata)
151153
if message.server_content:
152154
content = message.server_content.model_turn
153155
if content and content.parts:

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 70 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -109,3 +109,73 @@ async def test_close(gemini_connection, mock_gemini_session):
109109
await gemini_connection.close()
110110

111111
mock_gemini_session.close.assert_called_once()
112+
113+
114+
@pytest.mark.asyncio
115+
async def test_receive_usage_metadata_and_server_content(
116+
gemini_connection, mock_gemini_session
117+
):
118+
"""Test receive with usage metadata and server content in one message."""
119+
usage_metadata = types.UsageMetadata(
120+
prompt_token_count=10,
121+
cached_content_token_count=5,
122+
response_token_count=20,
123+
total_token_count=35,
124+
thoughts_token_count=2,
125+
prompt_tokens_details=[
126+
types.ModalityTokenCount(modality='text', token_count=10)
127+
],
128+
cache_tokens_details=[
129+
types.ModalityTokenCount(modality='text', token_count=5)
130+
],
131+
response_tokens_details=[
132+
types.ModalityTokenCount(modality='text', token_count=20)
133+
],
134+
)
135+
mock_content = types.Content(
136+
role='model', parts=[types.Part.from_text(text='response text')]
137+
)
138+
mock_server_content = mock.Mock()
139+
mock_server_content.model_turn = mock_content
140+
mock_server_content.interrupted = False
141+
mock_server_content.input_transcription = None
142+
mock_server_content.output_transcription = None
143+
mock_server_content.turn_complete = False
144+
145+
mock_message = mock.AsyncMock()
146+
mock_message.usage_metadata = usage_metadata
147+
mock_message.server_content = mock_server_content
148+
mock_message.tool_call = None
149+
mock_message.session_resumption_update = None
150+
151+
async def mock_receive_generator():
152+
yield mock_message
153+
154+
receive_mock = mock.Mock(return_value=mock_receive_generator())
155+
mock_gemini_session.receive = receive_mock
156+
157+
responses = [resp async for resp in gemini_connection.receive()]
158+
159+
assert responses
160+
161+
usage_response = next((r for r in responses if r.usage_metadata), None)
162+
assert usage_response is not None
163+
content_response = next((r for r in responses if r.content), None)
164+
assert content_response is not None
165+
166+
expected_usage = types.GenerateContentResponseUsageMetadata(
167+
prompt_token_count=10,
168+
cached_content_token_count=5,
169+
candidates_token_count=None,
170+
total_token_count=35,
171+
thoughts_token_count=2,
172+
prompt_tokens_details=[
173+
types.ModalityTokenCount(modality='text', token_count=10)
174+
],
175+
cache_tokens_details=[
176+
types.ModalityTokenCount(modality='text', token_count=5)
177+
],
178+
candidates_tokens_details=None,
179+
)
180+
assert usage_response.usage_metadata == expected_usage
181+
assert content_response.content == mock_content

0 commit comments

Comments (0)