Skip to content

Commit 8ab33b4

Browse files
committed
Refactoring
1 parent 098e0a5 commit 8ab33b4

File tree

1 file changed

+31
-66
lines changed

1 file changed

+31
-66
lines changed

src/a2a/server/tasks/database_task_store.py

Lines changed: 31 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -85,85 +85,50 @@ async def _ensure_initialized(self) -> None:
8585
if not self._initialized:
8686
await self.initialize()
8787

88+
def _to_orm(self, task: Task) -> TaskModel:
89+
"""Maps a Pydantic Task to a SQLAlchemy TaskModel instance."""
90+
return self.task_model(
91+
id=task.id,
92+
contextId=task.contextId,
93+
kind=task.kind,
94+
status=task.status,
95+
artifacts=task.artifacts,
96+
history=task.history,
97+
task_metadata=task.metadata,
98+
)
99+
100+
def _from_orm(self, task_model: TaskModel) -> Task:
101+
"""Maps a SQLAlchemy TaskModel to a Pydantic Task instance."""
102+
# Map database columns to Pydantic model fields
103+
task_data_from_db = {
104+
'id': task_model.id,
105+
'contextId': task_model.contextId,
106+
'kind': task_model.kind,
107+
'status': task_model.status,
108+
'artifacts': task_model.artifacts,
109+
'history': task_model.history,
110+
'metadata': task_model.task_metadata, # Map task_metadata column to metadata field
111+
}
112+
# Pydantic's model_validate will parse the nested dicts/lists from JSON
113+
return Task.model_validate(task_data_from_db)
114+
88115
async def save(self, task: Task) -> None:
89116
"""Saves or updates a task in the database."""
90117
await self._ensure_initialized()
91-
92-
task_data = task.model_dump(
93-
mode='json'
94-
) # Converts Pydantic Task to dict with JSON-serializable values
95-
118+
db_task = self._to_orm(task)
96119
async with self.async_session_maker.begin() as session:
97-
stmt_select = select(self.task_model).where(
98-
self.task_model.id == task.id
99-
)
100-
result = await session.execute(stmt_select)
101-
existing_task_model = result.scalar_one_or_none()
102-
103-
if existing_task_model:
104-
logger.debug(f'Updating task {task.id} in the database.')
105-
update_data = {
106-
'contextId': task_data['contextId'],
107-
'kind': task_data['kind'],
108-
'status': task_data[
109-
'status'
110-
], # Already a dict from model_dump
111-
'artifacts': task_data.get(
112-
'artifacts'
113-
), # Already a list of dicts
114-
'history': task_data.get(
115-
'history'
116-
), # Already a list of dicts
117-
'task_metadata': task_data.get(
118-
'metadata'
119-
), # Already a dict
120-
}
121-
stmt_update = (
122-
update(self.task_model)
123-
.where(self.task_model.id == task.id)
124-
.values(**update_data)
125-
)
126-
await session.execute(stmt_update)
127-
logger.debug(f'Task {task.id} updated successfully.')
128-
else:
129-
logger.debug(f'Saving new task {task.id} to the database.')
130-
# Map Pydantic fields to database columns
131-
new_task_model = self.task_model(
132-
id=task_data['id'],
133-
contextId=task_data['contextId'],
134-
kind=task_data['kind'],
135-
status=task_data['status'],
136-
artifacts=task_data.get('artifacts'),
137-
history=task_data.get('history'),
138-
task_metadata=task_data.get(
139-
'metadata'
140-
), # Map metadata field to task_metadata column
141-
)
142-
session.add(new_task_model)
143-
logger.info(f'Task {task.id} created successfully.')
120+
await session.merge(db_task)
121+
logger.debug(f'Task {task.id} saved/updated successfully.')
144122

145123
async def get(self, task_id: str) -> Task | None:
146124
"""Retrieves a task from the database by ID."""
147125
await self._ensure_initialized()
148-
149126
async with self.async_session_maker() as session:
150127
stmt = select(self.task_model).where(self.task_model.id == task_id)
151128
result = await session.execute(stmt)
152129
task_model = result.scalar_one_or_none()
153-
154130
if task_model:
155-
# Map database columns to Pydantic model fields
156-
task_data_from_db = {
157-
'id': task_model.id,
158-
'contextId': task_model.contextId,
159-
'kind': task_model.kind,
160-
'status': task_model.status,
161-
'artifacts': task_model.artifacts,
162-
'history': task_model.history,
163-
'metadata': task_model.task_metadata, # Map task_metadata column to metadata field
164-
}
165-
# Pydantic's model_validate will parse the nested dicts/lists from JSON
166-
task = Task.model_validate(task_data_from_db)
131+
task = self._from_orm(task_model)
167132
logger.debug(f'Task {task_id} retrieved successfully.')
168133
return task
169134

0 commit comments

Comments
 (0)