@@ -38,12 +38,12 @@ class CosmosDBClient(DatabaseBase):
3838    """CosmosDB implementation of the database interface.""" 
3939
4040    MODEL_CLASS_MAPPING  =  {
41-         " session"  : Session ,
42-         " plan"  : Plan ,
43-         " step"  : Step ,
44-         " agent_message"  : AgentMessage ,
45-         " team_config"  : TeamConfiguration ,
46-         " user_current_team"  : UserCurrentTeam ,
41+         DataType . session : Session ,
42+         DataType . plan : Plan ,
43+         DataType . step : Step ,
44+         DataType . agent_message : AgentMessage ,
45+         DataType . team_config : TeamConfiguration ,
46+         DataType . user_current_team : UserCurrentTeam ,
4747    }
4848
4949    def  __init__ (
@@ -200,7 +200,7 @@ async def get_session(self, session_id: str) -> Optional[Session]:
200200        query  =  "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type" 
201201        parameters  =  [
202202            {"name" : "@id" , "value" : session_id },
203-             {"name" : "@data_type" , "value" : " session"  },
203+             {"name" : "@data_type" , "value" : DataType . session },
204204        ]
205205        results  =  await  self .query_items (query , parameters , Session )
206206        return  results [0 ] if  results  else  None 
@@ -210,7 +210,7 @@ async def get_all_sessions(self) -> List[Session]:
210210        query  =  "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type" 
211211        parameters  =  [
212212            {"name" : "@user_id" , "value" : self .user_id },
213-             {"name" : "@data_type" , "value" : " session"  },
213+             {"name" : "@data_type" , "value" : DataType . session },
214214        ]
215215        return  await  self .query_items (query , parameters , Session )
216216
@@ -230,7 +230,7 @@ async def get_plan_by_session(self, session_id: str) -> Optional[Plan]:
230230        )
231231        parameters  =  [
232232            {"name" : "@session_id" , "value" : session_id },
233-             {"name" : "@data_type" , "value" : " plan"  },
233+             {"name" : "@data_type" , "value" : DataType . plan },
234234        ]
235235        results  =  await  self .query_items (query , parameters , Plan )
236236        return  results [0 ] if  results  else  None 
@@ -240,7 +240,7 @@ async def get_plan_by_plan_id(self, plan_id: str) -> Optional[Plan]:
240240        query  =  "SELECT * FROM c WHERE c.id=@plan_id AND c.data_type=@data_type" 
241241        parameters  =  [
242242            {"name" : "@plan_id" , "value" : plan_id },
243-             {"name" : "@data_type" , "value" : " plan"  },
243+             {"name" : "@data_type" , "value" : DataType . plan },
244244            {"name" : "@user_id" , "value" : self .user_id },
245245        ]
246246        results  =  await  self .query_items (query , parameters , Plan )
@@ -255,7 +255,7 @@ async def get_all_plans(self) -> List[Plan]:
255255        query  =  "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type" 
256256        parameters  =  [
257257            {"name" : "@user_id" , "value" : self .user_id },
258-             {"name" : "@data_type" , "value" : " plan"  },
258+             {"name" : "@data_type" , "value" : DataType . plan },
259259        ]
260260        return  await  self .query_items (query , parameters , Plan )
261261
@@ -265,7 +265,7 @@ async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]:
265265        parameters  =  [
266266            {"name" : "@user_id" , "value" : self .user_id },
267267            {"name" : "@team_id" , "value" : team_id },
268-             {"name" : "@data_type" , "value" : " plan"  },
268+             {"name" : "@data_type" , "value" : DataType . plan },
269269        ]
270270        return  await  self .query_items (query , parameters , Plan )
271271
@@ -283,7 +283,7 @@ async def get_steps_by_plan(self, plan_id: str) -> List[Step]:
283283        query  =  "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type ORDER BY c.timestamp" 
284284        parameters  =  [
285285            {"name" : "@plan_id" , "value" : plan_id },
286-             {"name" : "@data_type" , "value" : " step"  },
286+             {"name" : "@data_type" , "value" : DataType . step },
287287        ]
288288        return  await  self .query_items (query , parameters , Step )
289289
@@ -293,7 +293,7 @@ async def get_step(self, step_id: str, session_id: str) -> Optional[Step]:
293293        parameters  =  [
294294            {"name" : "@step_id" , "value" : step_id },
295295            {"name" : "@session_id" , "value" : session_id },
296-             {"name" : "@data_type" , "value" : " step"  },
296+             {"name" : "@data_type" , "value" : DataType . step },
297297        ]
298298        results  =  await  self .query_items (query , parameters , Step )
299299        return  results [0 ] if  results  else  None 
@@ -312,7 +312,7 @@ async def get_team(self, team_id: str) -> Optional[TeamConfiguration]:
312312        query  =  "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type" 
313313        parameters  =  [
314314            {"name" : "@team_id" , "value" : team_id },
315-             {"name" : "@data_type" , "value" : " team_config"  },
315+             {"name" : "@data_type" , "value" : DataType . team_config },
316316        ]
317317        teams  =  await  self .query_items (query , parameters , TeamConfiguration )
318318        return  teams [0 ] if  teams  else  None 
@@ -329,7 +329,7 @@ async def get_team_by_id(self, id: str) -> Optional[TeamConfiguration]:
329329        query  =  "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type" 
330330        parameters  =  [
331331            {"name" : "@id" , "value" : id },
332-             {"name" : "@data_type" , "value" : " team_config"  },
332+             {"name" : "@data_type" , "value" : DataType . team_config },
333333        ]
334334        teams  =  await  self .query_items (query , parameters , TeamConfiguration )
335335        return  teams [0 ] if  teams  else  None 
@@ -346,7 +346,7 @@ async def get_all_teams_by_user(self, user_id: str) -> List[TeamConfiguration]:
346346        query  =  "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type ORDER BY c.created DESC" 
347347        parameters  =  [
348348            {"name" : "@user_id" , "value" : user_id },
349-             {"name" : "@data_type" , "value" : " team_config"  },
349+             {"name" : "@data_type" , "value" : DataType . team_config },
350350        ]
351351        teams  =  await  self .query_items (query , parameters , TeamConfiguration )
352352        return  teams 
@@ -452,7 +452,7 @@ async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
452452
453453        query  =  "SELECT * FROM c WHERE c.data_type=@data_type AND c.user_id=@user_id" 
454454        parameters  =  [
455-             {"name" : "@data_type" , "value" : " user_current_team"  },
455+             {"name" : "@data_type" , "value" : DataType . user_current_team },
456456            {"name" : "@user_id" , "value" : user_id },
457457        ]
458458
0 commit comments