Skip to content

Commit b40345c

Browse files
authored
Add Sagas (#1149)
* Add Sagas * update * update
1 parent 6470423 commit b40345c

File tree

11 files changed

+3693
-1973
lines changed

11 files changed

+3693
-1973
lines changed

examples/podcast/podcast_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ async def main(use_bulk: bool = False):
101101
entity_types={'Person': Person, 'City': City},
102102
edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
103103
edge_type_map={('Person', 'Entity'): ['IS_PRESIDENT_OF']},
104+
saga='Freakonomics Podcast',
104105
)
105106
else:
106107
for i, message in enumerate(messages[3:14]):
@@ -119,6 +120,7 @@ async def main(use_bulk: bool = False):
119120
edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
120121
edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
121122
previous_episode_uuids=episode_uuids,
123+
saga='Freakonomics Podcast',
122124
)
123125

124126

graphiti_core/driver/graph_operations/graph_operations.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,47 @@ async def episodic_node_delete_by_uuids(
138138
) -> None:
139139
raise NotImplementedError
140140

141+
# -----------------------
142+
# SagaNode: Save/Delete
143+
# -----------------------
144+
145+
async def saga_node_save(self, node: Any, driver: Any) -> None:
146+
"""Persist (create or update) a single saga node."""
147+
raise NotImplementedError
148+
149+
async def saga_node_delete(self, node: Any, driver: Any) -> None:
150+
raise NotImplementedError
151+
152+
async def saga_node_save_bulk(
153+
self,
154+
_cls: Any,
155+
driver: Any,
156+
transaction: Any,
157+
nodes: list[Any],
158+
batch_size: int = 100,
159+
) -> None:
160+
"""Persist (create or update) many saga nodes in batches."""
161+
raise NotImplementedError
162+
163+
async def saga_node_delete_by_group_id(
164+
self,
165+
_cls: Any,
166+
driver: Any,
167+
group_id: str,
168+
batch_size: int = 100,
169+
) -> None:
170+
raise NotImplementedError
171+
172+
async def saga_node_delete_by_uuids(
173+
self,
174+
_cls: Any,
175+
driver: Any,
176+
uuids: list[str],
177+
group_id: str | None = None,
178+
batch_size: int = 100,
179+
) -> None:
180+
raise NotImplementedError
181+
141182
# -----------------
142183
# Edge: Save/Delete
143184
# -----------------
@@ -189,3 +230,65 @@ async def edge_load_embeddings_bulk(
189230
Load embedding vectors for many edges in batches
190231
"""
191232
raise NotImplementedError
233+
234+
# ---------------------------
235+
# HasEpisodeEdge: Save/Delete
236+
# ---------------------------
237+
238+
async def has_episode_edge_save(self, edge: Any, driver: Any) -> None:
239+
"""Persist (create or update) a single has_episode edge."""
240+
raise NotImplementedError
241+
242+
async def has_episode_edge_delete(self, edge: Any, driver: Any) -> None:
243+
raise NotImplementedError
244+
245+
async def has_episode_edge_save_bulk(
246+
self,
247+
_cls: Any,
248+
driver: Any,
249+
transaction: Any,
250+
edges: list[Any],
251+
batch_size: int = 100,
252+
) -> None:
253+
"""Persist (create or update) many has_episode edges in batches."""
254+
raise NotImplementedError
255+
256+
async def has_episode_edge_delete_by_uuids(
257+
self,
258+
_cls: Any,
259+
driver: Any,
260+
uuids: list[str],
261+
group_id: str | None = None,
262+
) -> None:
263+
raise NotImplementedError
264+
265+
# ----------------------------
266+
# NextEpisodeEdge: Save/Delete
267+
# ----------------------------
268+
269+
async def next_episode_edge_save(self, edge: Any, driver: Any) -> None:
270+
"""Persist (create or update) a single next_episode edge."""
271+
raise NotImplementedError
272+
273+
async def next_episode_edge_delete(self, edge: Any, driver: Any) -> None:
274+
raise NotImplementedError
275+
276+
async def next_episode_edge_save_bulk(
277+
self,
278+
_cls: Any,
279+
driver: Any,
280+
transaction: Any,
281+
edges: list[Any],
282+
batch_size: int = 100,
283+
) -> None:
284+
"""Persist (create or update) many next_episode edges in batches."""
285+
raise NotImplementedError
286+
287+
async def next_episode_edge_delete_by_uuids(
288+
self,
289+
_cls: Any,
290+
driver: Any,
291+
uuids: list[str],
292+
group_id: str | None = None,
293+
) -> None:
294+
raise NotImplementedError

graphiti_core/driver/neptune_driver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,10 @@ def _sanitize_parameters(self, query, params: dict):
190190

191191
async def execute_query(
192192
self, cypher_query_, **kwargs: Any
193-
) -> tuple[dict[str, Any], None, None]:
193+
) -> tuple[list[dict[str, Any]], None, None]:
194194
params = dict(kwargs)
195195
if isinstance(cypher_query_, list):
196+
result: list[dict[str, Any]] = []
196197
for q in cypher_query_:
197198
result, _, _ = self._run_query(q[0], q[1])
198199
return result, None, None

graphiti_core/edges.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
COMMUNITY_EDGE_RETURN,
3434
EPISODIC_EDGE_RETURN,
3535
EPISODIC_EDGE_SAVE,
36+
HAS_EPISODE_EDGE_RETURN,
37+
HAS_EPISODE_EDGE_SAVE,
38+
NEXT_EPISODE_EDGE_RETURN,
39+
NEXT_EPISODE_EDGE_SAVE,
3640
get_community_edge_save_query,
3741
get_entity_edge_return_query,
3842
get_entity_edge_save_query,
@@ -561,6 +565,200 @@ async def get_by_group_ids(
561565
return edges
562566

563567

568+
class HasEpisodeEdge(Edge):
569+
async def save(self, driver: GraphDriver):
570+
result = await driver.execute_query(
571+
HAS_EPISODE_EDGE_SAVE,
572+
saga_uuid=self.source_node_uuid,
573+
episode_uuid=self.target_node_uuid,
574+
uuid=self.uuid,
575+
group_id=self.group_id,
576+
created_at=self.created_at,
577+
)
578+
579+
logger.debug(f'Saved edge to Graph: {self.uuid}')
580+
581+
return result
582+
583+
async def delete(self, driver: GraphDriver):
584+
await driver.execute_query(
585+
"""
586+
MATCH (n:Saga)-[e:HAS_EPISODE {uuid: $uuid}]->(m:Episodic)
587+
DELETE e
588+
""",
589+
uuid=self.uuid,
590+
)
591+
592+
logger.debug(f'Deleted Edge: {self.uuid}')
593+
594+
@classmethod
595+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
596+
records, _, _ = await driver.execute_query(
597+
"""
598+
MATCH (n:Saga)-[e:HAS_EPISODE {uuid: $uuid}]->(m:Episodic)
599+
RETURN
600+
"""
601+
+ HAS_EPISODE_EDGE_RETURN,
602+
uuid=uuid,
603+
routing_='r',
604+
)
605+
606+
edges = [get_has_episode_edge_from_record(record) for record in records]
607+
608+
if len(edges) == 0:
609+
raise EdgeNotFoundError(uuid)
610+
return edges[0]
611+
612+
@classmethod
613+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
614+
records, _, _ = await driver.execute_query(
615+
"""
616+
MATCH (n:Saga)-[e:HAS_EPISODE]->(m:Episodic)
617+
WHERE e.uuid IN $uuids
618+
RETURN
619+
"""
620+
+ HAS_EPISODE_EDGE_RETURN,
621+
uuids=uuids,
622+
routing_='r',
623+
)
624+
625+
edges = [get_has_episode_edge_from_record(record) for record in records]
626+
627+
return edges
628+
629+
@classmethod
630+
async def get_by_group_ids(
631+
cls,
632+
driver: GraphDriver,
633+
group_ids: list[str],
634+
limit: int | None = None,
635+
uuid_cursor: str | None = None,
636+
):
637+
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
638+
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
639+
640+
records, _, _ = await driver.execute_query(
641+
"""
642+
MATCH (n:Saga)-[e:HAS_EPISODE]->(m:Episodic)
643+
WHERE e.group_id IN $group_ids
644+
"""
645+
+ cursor_query
646+
+ """
647+
RETURN
648+
"""
649+
+ HAS_EPISODE_EDGE_RETURN
650+
+ """
651+
ORDER BY e.uuid DESC
652+
"""
653+
+ limit_query,
654+
group_ids=group_ids,
655+
uuid=uuid_cursor,
656+
limit=limit,
657+
routing_='r',
658+
)
659+
660+
edges = [get_has_episode_edge_from_record(record) for record in records]
661+
662+
return edges
663+
664+
665+
class NextEpisodeEdge(Edge):
666+
async def save(self, driver: GraphDriver):
667+
result = await driver.execute_query(
668+
NEXT_EPISODE_EDGE_SAVE,
669+
source_episode_uuid=self.source_node_uuid,
670+
target_episode_uuid=self.target_node_uuid,
671+
uuid=self.uuid,
672+
group_id=self.group_id,
673+
created_at=self.created_at,
674+
)
675+
676+
logger.debug(f'Saved edge to Graph: {self.uuid}')
677+
678+
return result
679+
680+
async def delete(self, driver: GraphDriver):
681+
await driver.execute_query(
682+
"""
683+
MATCH (n:Episodic)-[e:NEXT_EPISODE {uuid: $uuid}]->(m:Episodic)
684+
DELETE e
685+
""",
686+
uuid=self.uuid,
687+
)
688+
689+
logger.debug(f'Deleted Edge: {self.uuid}')
690+
691+
@classmethod
692+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
693+
records, _, _ = await driver.execute_query(
694+
"""
695+
MATCH (n:Episodic)-[e:NEXT_EPISODE {uuid: $uuid}]->(m:Episodic)
696+
RETURN
697+
"""
698+
+ NEXT_EPISODE_EDGE_RETURN,
699+
uuid=uuid,
700+
routing_='r',
701+
)
702+
703+
edges = [get_next_episode_edge_from_record(record) for record in records]
704+
705+
if len(edges) == 0:
706+
raise EdgeNotFoundError(uuid)
707+
return edges[0]
708+
709+
@classmethod
710+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
711+
records, _, _ = await driver.execute_query(
712+
"""
713+
MATCH (n:Episodic)-[e:NEXT_EPISODE]->(m:Episodic)
714+
WHERE e.uuid IN $uuids
715+
RETURN
716+
"""
717+
+ NEXT_EPISODE_EDGE_RETURN,
718+
uuids=uuids,
719+
routing_='r',
720+
)
721+
722+
edges = [get_next_episode_edge_from_record(record) for record in records]
723+
724+
return edges
725+
726+
@classmethod
727+
async def get_by_group_ids(
728+
cls,
729+
driver: GraphDriver,
730+
group_ids: list[str],
731+
limit: int | None = None,
732+
uuid_cursor: str | None = None,
733+
):
734+
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
735+
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
736+
737+
records, _, _ = await driver.execute_query(
738+
"""
739+
MATCH (n:Episodic)-[e:NEXT_EPISODE]->(m:Episodic)
740+
WHERE e.group_id IN $group_ids
741+
"""
742+
+ cursor_query
743+
+ """
744+
RETURN
745+
"""
746+
+ NEXT_EPISODE_EDGE_RETURN
747+
+ """
748+
ORDER BY e.uuid DESC
749+
"""
750+
+ limit_query,
751+
group_ids=group_ids,
752+
uuid=uuid_cursor,
753+
limit=limit,
754+
routing_='r',
755+
)
756+
757+
edges = [get_next_episode_edge_from_record(record) for record in records]
758+
759+
return edges
760+
761+
564762
# Edge helpers
565763
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
566764
return EpisodicEdge(
@@ -620,6 +818,26 @@ def get_community_edge_from_record(record: Any):
620818
)
621819

622820

821+
def get_has_episode_edge_from_record(record: Any) -> HasEpisodeEdge:
822+
return HasEpisodeEdge(
823+
uuid=record['uuid'],
824+
group_id=record['group_id'],
825+
source_node_uuid=record['source_node_uuid'],
826+
target_node_uuid=record['target_node_uuid'],
827+
created_at=parse_db_date(record['created_at']), # type: ignore
828+
)
829+
830+
831+
def get_next_episode_edge_from_record(record: Any) -> NextEpisodeEdge:
832+
return NextEpisodeEdge(
833+
uuid=record['uuid'],
834+
group_id=record['group_id'],
835+
source_node_uuid=record['source_node_uuid'],
836+
target_node_uuid=record['target_node_uuid'],
837+
created_at=parse_db_date(record['created_at']), # type: ignore
838+
)
839+
840+
623841
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
624842
# filter out falsey values from edges
625843
filtered_edges = [edge for edge in edges if edge.fact]

0 commit comments

Comments
 (0)