@@ -85,85 +85,50 @@ async def _ensure_initialized(self) -> None:
8585 if not self ._initialized :
8686 await self .initialize ()
8787
88+ def _to_orm (self , task : Task ) -> TaskModel :
89+ """Maps a Pydantic Task to a SQLAlchemy TaskModel instance."""
90+ return self .task_model (
91+ id = task .id ,
92+ contextId = task .contextId ,
93+ kind = task .kind ,
94+ status = task .status ,
95+ artifacts = task .artifacts ,
96+ history = task .history ,
97+ task_metadata = task .metadata ,
98+ )
99+
100+ def _from_orm (self , task_model : TaskModel ) -> Task :
101+ """Maps a SQLAlchemy TaskModel to a Pydantic Task instance."""
102+ # Map database columns to Pydantic model fields
103+ task_data_from_db = {
104+ 'id' : task_model .id ,
105+ 'contextId' : task_model .contextId ,
106+ 'kind' : task_model .kind ,
107+ 'status' : task_model .status ,
108+ 'artifacts' : task_model .artifacts ,
109+ 'history' : task_model .history ,
110+ 'metadata' : task_model .task_metadata , # Map task_metadata column to metadata field
111+ }
112+ # Pydantic's model_validate will parse the nested dicts/lists from JSON
113+ return Task .model_validate (task_data_from_db )
114+
88115 async def save (self , task : Task ) -> None :
89116 """Saves or updates a task in the database."""
90117 await self ._ensure_initialized ()
91-
92- task_data = task .model_dump (
93- mode = 'json'
94- ) # Converts Pydantic Task to dict with JSON-serializable values
95-
118+ db_task = self ._to_orm (task )
96119 async with self .async_session_maker .begin () as session :
97- stmt_select = select (self .task_model ).where (
98- self .task_model .id == task .id
99- )
100- result = await session .execute (stmt_select )
101- existing_task_model = result .scalar_one_or_none ()
102-
103- if existing_task_model :
104- logger .debug (f'Updating task { task .id } in the database.' )
105- update_data = {
106- 'contextId' : task_data ['contextId' ],
107- 'kind' : task_data ['kind' ],
108- 'status' : task_data [
109- 'status'
110- ], # Already a dict from model_dump
111- 'artifacts' : task_data .get (
112- 'artifacts'
113- ), # Already a list of dicts
114- 'history' : task_data .get (
115- 'history'
116- ), # Already a list of dicts
117- 'task_metadata' : task_data .get (
118- 'metadata'
119- ), # Already a dict
120- }
121- stmt_update = (
122- update (self .task_model )
123- .where (self .task_model .id == task .id )
124- .values (** update_data )
125- )
126- await session .execute (stmt_update )
127- logger .debug (f'Task { task .id } updated successfully.' )
128- else :
129- logger .debug (f'Saving new task { task .id } to the database.' )
130- # Map Pydantic fields to database columns
131- new_task_model = self .task_model (
132- id = task_data ['id' ],
133- contextId = task_data ['contextId' ],
134- kind = task_data ['kind' ],
135- status = task_data ['status' ],
136- artifacts = task_data .get ('artifacts' ),
137- history = task_data .get ('history' ),
138- task_metadata = task_data .get (
139- 'metadata'
140- ), # Map metadata field to task_metadata column
141- )
142- session .add (new_task_model )
143- logger .info (f'Task { task .id } created successfully.' )
120+ await session .merge (db_task )
121+ logger .debug (f'Task { task .id } saved/updated successfully.' )
144122
145123 async def get (self , task_id : str ) -> Task | None :
146124 """Retrieves a task from the database by ID."""
147125 await self ._ensure_initialized ()
148-
149126 async with self .async_session_maker () as session :
150127 stmt = select (self .task_model ).where (self .task_model .id == task_id )
151128 result = await session .execute (stmt )
152129 task_model = result .scalar_one_or_none ()
153-
154130 if task_model :
155- # Map database columns to Pydantic model fields
156- task_data_from_db = {
157- 'id' : task_model .id ,
158- 'contextId' : task_model .contextId ,
159- 'kind' : task_model .kind ,
160- 'status' : task_model .status ,
161- 'artifacts' : task_model .artifacts ,
162- 'history' : task_model .history ,
163- 'metadata' : task_model .task_metadata , # Map task_metadata column to metadata field
164- }
165- # Pydantic's model_validate will parse the nested dicts/lists from JSON
166- task = Task .model_validate (task_data_from_db )
131+ task = self ._from_orm (task_model )
167132 logger .debug (f'Task { task_id } retrieved successfully.' )
168133 return task
169134
0 commit comments