|
33 | 33 | COMMUNITY_EDGE_RETURN, |
34 | 34 | EPISODIC_EDGE_RETURN, |
35 | 35 | EPISODIC_EDGE_SAVE, |
| 36 | + HAS_EPISODE_EDGE_RETURN, |
| 37 | + HAS_EPISODE_EDGE_SAVE, |
| 38 | + NEXT_EPISODE_EDGE_RETURN, |
| 39 | + NEXT_EPISODE_EDGE_SAVE, |
36 | 40 | get_community_edge_save_query, |
37 | 41 | get_entity_edge_return_query, |
38 | 42 | get_entity_edge_save_query, |
@@ -561,6 +565,200 @@ async def get_by_group_ids( |
561 | 565 | return edges |
562 | 566 |
|
563 | 567 |
|
| 568 | +class HasEpisodeEdge(Edge): |
| 569 | + async def save(self, driver: GraphDriver): |
| 570 | + result = await driver.execute_query( |
| 571 | + HAS_EPISODE_EDGE_SAVE, |
| 572 | + saga_uuid=self.source_node_uuid, |
| 573 | + episode_uuid=self.target_node_uuid, |
| 574 | + uuid=self.uuid, |
| 575 | + group_id=self.group_id, |
| 576 | + created_at=self.created_at, |
| 577 | + ) |
| 578 | + |
| 579 | + logger.debug(f'Saved edge to Graph: {self.uuid}') |
| 580 | + |
| 581 | + return result |
| 582 | + |
| 583 | + async def delete(self, driver: GraphDriver): |
| 584 | + await driver.execute_query( |
| 585 | + """ |
| 586 | + MATCH (n:Saga)-[e:HAS_EPISODE {uuid: $uuid}]->(m:Episodic) |
| 587 | + DELETE e |
| 588 | + """, |
| 589 | + uuid=self.uuid, |
| 590 | + ) |
| 591 | + |
| 592 | + logger.debug(f'Deleted Edge: {self.uuid}') |
| 593 | + |
| 594 | + @classmethod |
| 595 | + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): |
| 596 | + records, _, _ = await driver.execute_query( |
| 597 | + """ |
| 598 | + MATCH (n:Saga)-[e:HAS_EPISODE {uuid: $uuid}]->(m:Episodic) |
| 599 | + RETURN |
| 600 | + """ |
| 601 | + + HAS_EPISODE_EDGE_RETURN, |
| 602 | + uuid=uuid, |
| 603 | + routing_='r', |
| 604 | + ) |
| 605 | + |
| 606 | + edges = [get_has_episode_edge_from_record(record) for record in records] |
| 607 | + |
| 608 | + if len(edges) == 0: |
| 609 | + raise EdgeNotFoundError(uuid) |
| 610 | + return edges[0] |
| 611 | + |
| 612 | + @classmethod |
| 613 | + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): |
| 614 | + records, _, _ = await driver.execute_query( |
| 615 | + """ |
| 616 | + MATCH (n:Saga)-[e:HAS_EPISODE]->(m:Episodic) |
| 617 | + WHERE e.uuid IN $uuids |
| 618 | + RETURN |
| 619 | + """ |
| 620 | + + HAS_EPISODE_EDGE_RETURN, |
| 621 | + uuids=uuids, |
| 622 | + routing_='r', |
| 623 | + ) |
| 624 | + |
| 625 | + edges = [get_has_episode_edge_from_record(record) for record in records] |
| 626 | + |
| 627 | + return edges |
| 628 | + |
| 629 | + @classmethod |
| 630 | + async def get_by_group_ids( |
| 631 | + cls, |
| 632 | + driver: GraphDriver, |
| 633 | + group_ids: list[str], |
| 634 | + limit: int | None = None, |
| 635 | + uuid_cursor: str | None = None, |
| 636 | + ): |
| 637 | + cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else '' |
| 638 | + limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' |
| 639 | + |
| 640 | + records, _, _ = await driver.execute_query( |
| 641 | + """ |
| 642 | + MATCH (n:Saga)-[e:HAS_EPISODE]->(m:Episodic) |
| 643 | + WHERE e.group_id IN $group_ids |
| 644 | + """ |
| 645 | + + cursor_query |
| 646 | + + """ |
| 647 | + RETURN |
| 648 | + """ |
| 649 | + + HAS_EPISODE_EDGE_RETURN |
| 650 | + + """ |
| 651 | + ORDER BY e.uuid DESC |
| 652 | + """ |
| 653 | + + limit_query, |
| 654 | + group_ids=group_ids, |
| 655 | + uuid=uuid_cursor, |
| 656 | + limit=limit, |
| 657 | + routing_='r', |
| 658 | + ) |
| 659 | + |
| 660 | + edges = [get_has_episode_edge_from_record(record) for record in records] |
| 661 | + |
| 662 | + return edges |
| 663 | + |
| 664 | + |
| 665 | +class NextEpisodeEdge(Edge): |
| 666 | + async def save(self, driver: GraphDriver): |
| 667 | + result = await driver.execute_query( |
| 668 | + NEXT_EPISODE_EDGE_SAVE, |
| 669 | + source_episode_uuid=self.source_node_uuid, |
| 670 | + target_episode_uuid=self.target_node_uuid, |
| 671 | + uuid=self.uuid, |
| 672 | + group_id=self.group_id, |
| 673 | + created_at=self.created_at, |
| 674 | + ) |
| 675 | + |
| 676 | + logger.debug(f'Saved edge to Graph: {self.uuid}') |
| 677 | + |
| 678 | + return result |
| 679 | + |
| 680 | + async def delete(self, driver: GraphDriver): |
| 681 | + await driver.execute_query( |
| 682 | + """ |
| 683 | + MATCH (n:Episodic)-[e:NEXT_EPISODE {uuid: $uuid}]->(m:Episodic) |
| 684 | + DELETE e |
| 685 | + """, |
| 686 | + uuid=self.uuid, |
| 687 | + ) |
| 688 | + |
| 689 | + logger.debug(f'Deleted Edge: {self.uuid}') |
| 690 | + |
| 691 | + @classmethod |
| 692 | + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): |
| 693 | + records, _, _ = await driver.execute_query( |
| 694 | + """ |
| 695 | + MATCH (n:Episodic)-[e:NEXT_EPISODE {uuid: $uuid}]->(m:Episodic) |
| 696 | + RETURN |
| 697 | + """ |
| 698 | + + NEXT_EPISODE_EDGE_RETURN, |
| 699 | + uuid=uuid, |
| 700 | + routing_='r', |
| 701 | + ) |
| 702 | + |
| 703 | + edges = [get_next_episode_edge_from_record(record) for record in records] |
| 704 | + |
| 705 | + if len(edges) == 0: |
| 706 | + raise EdgeNotFoundError(uuid) |
| 707 | + return edges[0] |
| 708 | + |
| 709 | + @classmethod |
| 710 | + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): |
| 711 | + records, _, _ = await driver.execute_query( |
| 712 | + """ |
| 713 | + MATCH (n:Episodic)-[e:NEXT_EPISODE]->(m:Episodic) |
| 714 | + WHERE e.uuid IN $uuids |
| 715 | + RETURN |
| 716 | + """ |
| 717 | + + NEXT_EPISODE_EDGE_RETURN, |
| 718 | + uuids=uuids, |
| 719 | + routing_='r', |
| 720 | + ) |
| 721 | + |
| 722 | + edges = [get_next_episode_edge_from_record(record) for record in records] |
| 723 | + |
| 724 | + return edges |
| 725 | + |
| 726 | + @classmethod |
| 727 | + async def get_by_group_ids( |
| 728 | + cls, |
| 729 | + driver: GraphDriver, |
| 730 | + group_ids: list[str], |
| 731 | + limit: int | None = None, |
| 732 | + uuid_cursor: str | None = None, |
| 733 | + ): |
| 734 | + cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else '' |
| 735 | + limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' |
| 736 | + |
| 737 | + records, _, _ = await driver.execute_query( |
| 738 | + """ |
| 739 | + MATCH (n:Episodic)-[e:NEXT_EPISODE]->(m:Episodic) |
| 740 | + WHERE e.group_id IN $group_ids |
| 741 | + """ |
| 742 | + + cursor_query |
| 743 | + + """ |
| 744 | + RETURN |
| 745 | + """ |
| 746 | + + NEXT_EPISODE_EDGE_RETURN |
| 747 | + + """ |
| 748 | + ORDER BY e.uuid DESC |
| 749 | + """ |
| 750 | + + limit_query, |
| 751 | + group_ids=group_ids, |
| 752 | + uuid=uuid_cursor, |
| 753 | + limit=limit, |
| 754 | + routing_='r', |
| 755 | + ) |
| 756 | + |
| 757 | + edges = [get_next_episode_edge_from_record(record) for record in records] |
| 758 | + |
| 759 | + return edges |
| 760 | + |
| 761 | + |
564 | 762 | # Edge helpers |
565 | 763 | def get_episodic_edge_from_record(record: Any) -> EpisodicEdge: |
566 | 764 | return EpisodicEdge( |
@@ -620,6 +818,26 @@ def get_community_edge_from_record(record: Any): |
620 | 818 | ) |
621 | 819 |
|
622 | 820 |
|
| 821 | +def get_has_episode_edge_from_record(record: Any) -> HasEpisodeEdge: |
| 822 | + return HasEpisodeEdge( |
| 823 | + uuid=record['uuid'], |
| 824 | + group_id=record['group_id'], |
| 825 | + source_node_uuid=record['source_node_uuid'], |
| 826 | + target_node_uuid=record['target_node_uuid'], |
| 827 | + created_at=parse_db_date(record['created_at']), # type: ignore |
| 828 | + ) |
| 829 | + |
| 830 | + |
| 831 | +def get_next_episode_edge_from_record(record: Any) -> NextEpisodeEdge: |
| 832 | + return NextEpisodeEdge( |
| 833 | + uuid=record['uuid'], |
| 834 | + group_id=record['group_id'], |
| 835 | + source_node_uuid=record['source_node_uuid'], |
| 836 | + target_node_uuid=record['target_node_uuid'], |
| 837 | + created_at=parse_db_date(record['created_at']), # type: ignore |
| 838 | + ) |
| 839 | + |
| 840 | + |
623 | 841 | async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]): |
624 | 842 | # filter out falsey values from edges |
625 | 843 | filtered_edges = [edge for edge in edges if edge.fact] |
|
0 commit comments