Skip to content

Commit 6ab1498

Browse files
DeanChensjcopybara-github
authored andcommitted
fix: Add usage_metadata and citation_metadata to DatabaseSessionService
PiperOrigin-RevId: 819900773
1 parent 2424d6a commit 6ab1498

File tree

4 files changed

+44
-24
lines changed

4 files changed

+44
-24
lines changed

src/google/adk/sessions/_session_util.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,16 @@
1616

1717
from typing import Any
1818
from typing import Optional
19+
from typing import Type
20+
from typing import TypeVar
1921

20-
from google.genai import types
22+
M = TypeVar("M")
2123

2224

23-
def decode_content(
24-
content: Optional[dict[str, Any]],
25-
) -> Optional[types.Content]:
26-
"""Decodes a content object from a JSON dictionary."""
27-
if not content:
25+
def decode_model(
26+
data: Optional[dict[str, Any]], model_cls: Type[M]
27+
) -> Optional[M]:
28+
"""Decodes a pydantic model object from a JSON dictionary."""
29+
if data is None:
2830
return None
29-
return types.Content.model_validate(content)
30-
31-
32-
def decode_grounding_metadata(
33-
grounding_metadata: Optional[dict[str, Any]],
34-
) -> Optional[types.GroundingMetadata]:
35-
"""Decodes a grounding metadata object from a JSON dictionary."""
36-
if not grounding_metadata:
37-
return None
38-
return types.GroundingMetadata.model_validate(grounding_metadata)
31+
return model_cls.model_validate(data)

src/google/adk/sessions/database_session_service.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing import Optional
2424
import uuid
2525

26+
from google.genai import types
2627
from sqlalchemy import Boolean
2728
from sqlalchemy import delete
2829
from sqlalchemy import Dialect
@@ -252,6 +253,12 @@ class StorageEvent(Base):
252253
custom_metadata: Mapped[dict[str, Any]] = mapped_column(
253254
DynamicJSON, nullable=True
254255
)
256+
usage_metadata: Mapped[dict[str, Any]] = mapped_column(
257+
DynamicJSON, nullable=True
258+
)
259+
citation_metadata: Mapped[dict[str, Any]] = mapped_column(
260+
DynamicJSON, nullable=True
261+
)
255262

256263
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
257264
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
@@ -318,6 +325,14 @@ def from_event(cls, session: Session, event: Event) -> StorageEvent:
318325
)
319326
if event.custom_metadata:
320327
storage_event.custom_metadata = event.custom_metadata
328+
if event.usage_metadata:
329+
storage_event.usage_metadata = event.usage_metadata.model_dump(
330+
exclude_none=True, mode="json"
331+
)
332+
if event.citation_metadata:
333+
storage_event.citation_metadata = event.citation_metadata.model_dump(
334+
exclude_none=True, mode="json"
335+
)
321336
return storage_event
322337

323338
def to_event(self) -> Event:
@@ -328,17 +343,23 @@ def to_event(self) -> Event:
328343
branch=self.branch,
329344
actions=self.actions,
330345
timestamp=self.timestamp.timestamp(),
331-
content=_session_util.decode_content(self.content),
332346
long_running_tool_ids=self.long_running_tool_ids,
333347
partial=self.partial,
334348
turn_complete=self.turn_complete,
335349
error_code=self.error_code,
336350
error_message=self.error_message,
337351
interrupted=self.interrupted,
338-
grounding_metadata=_session_util.decode_grounding_metadata(
339-
self.grounding_metadata
340-
),
341352
custom_metadata=self.custom_metadata,
353+
content=_session_util.decode_model(self.content, types.Content),
354+
grounding_metadata=_session_util.decode_model(
355+
self.grounding_metadata, types.GroundingMetadata
356+
),
357+
usage_metadata=_session_util.decode_model(
358+
self.usage_metadata, types.GenerateContentResponseUsageMetadata
359+
),
360+
citation_metadata=_session_util.decode_model(
361+
self.citation_metadata, types.CitationMetadata
362+
),
342363
)
343364

344365

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,9 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
376376
interrupted = getattr(event_metadata, 'interrupted', None)
377377
branch = getattr(event_metadata, 'branch', None)
378378
custom_metadata = getattr(event_metadata, 'custom_metadata', None)
379-
grounding_metadata = _session_util.decode_grounding_metadata(
380-
getattr(event_metadata, 'grounding_metadata', None)
379+
grounding_metadata = _session_util.decode_model(
380+
getattr(event_metadata, 'grounding_metadata', None),
381+
types.GroundingMetadata,
381382
)
382383
else:
383384
long_running_tool_ids = None
@@ -393,8 +394,8 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
393394
invocation_id=api_event_obj.invocation_id,
394395
author=api_event_obj.author,
395396
actions=event_actions,
396-
content=_session_util.decode_content(
397-
getattr(api_event_obj, 'content', None)
397+
content=_session_util.decode_model(
398+
getattr(api_event_obj, 'content', None), types.Content
398399
),
399400
timestamp=api_event_obj.timestamp.timestamp(),
400401
error_code=getattr(api_event_obj, 'error_code', None),

tests/unittests/sessions/test_session_service.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,11 @@ async def test_append_event_complete(service_type):
390390
error_code='error_code',
391391
error_message='error_message',
392392
interrupted=True,
393+
usage_metadata=types.GenerateContentResponseUsageMetadata(
394+
prompt_token_count=1, candidates_token_count=1, total_token_count=2
395+
),
396+
citation_metadata=types.CitationMetadata(),
397+
custom_metadata={'custom_key': 'custom_value'},
393398
)
394399
await session_service.append_event(session=session, event=event)
395400

0 commit comments

Comments
 (0)