libs/knowledge-store/ragstack_knowledge_store/graph_store.py (171 changes: 82 additions, 89 deletions)

@@ -31,10 +31,8 @@

CONTENT_ID = "content_id"

-CONTENT_COLUMNS = "content_id, text_content, metadata_blob"
-
SELECT_CQL_TEMPLATE = (
"SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause};"
"SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause}"
)
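For context, this template is filled in by _get_search_cql_and_params, which is only partially visible in this diff. A minimal sketch of how it might be instantiated; the clause values here are hypothetical:

# Illustrative sketch only; the clause values are hypothetical.
query = SELECT_CQL_TEMPLATE.format(
    columns="content_id, text_content, metadata_blob",
    table_name="my_keyspace.my_nodes",
    where_clause=" WHERE metadata_s['section'] = ?",
    order_clause=" ORDER BY text_embedding ANN OF ?",
    limit_clause=" LIMIT 10",
)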


@@ -221,13 +219,13 @@ def __init__(

self._query_by_id = session.prepare(
f"""
-SELECT {CONTENT_COLUMNS}
+SELECT content_id, text_content, metadata_blob
FROM {keyspace}.{node_table}
WHERE content_id = ?
""" # noqa: S608
)

-self._query_ids_and_link_to_tags_by_id = session.prepare(
+self._query_id_and_metadata_by_id = session.prepare(
f"""
SELECT content_id, metadata_blob
FROM {keyspace}.{node_table}
@@ -247,10 +245,8 @@ def _apply_schema(self) -> None:
content_id TEXT,
text_content TEXT,
text_embedding VECTOR<FLOAT, {embedding_dim}>,
-
metadata_blob TEXT,
metadata_s MAP<TEXT,TEXT>,
-
PRIMARY KEY (content_id)
)
""")
@@ -279,36 +275,36 @@ def add_nodes(
"""Add nodes to the graph store."""
node_ids: list[str] = []
texts: list[str] = []
-metadatas: list[dict[str, Any]] = []
-nodes_links: list[set[Link]] = []
+metadata_list: list[dict[str, Any]] = []
+incoming_links_list: list[set[Link]] = []
for node in nodes:
if not node.id:
node_ids.append(secrets.token_hex(8))
else:
node_ids.append(node.id)
texts.append(node.text)
-metadatas.append(node.metadata)
-nodes_links.append(node.links)
+combined_metadata = node.metadata.copy()
+combined_metadata["links"] = _serialize_links(node.links)
+metadata_list.append(combined_metadata)
+incoming_links_list.append(node.incoming_links())

text_embeddings = self._embedding.embed_texts(texts)

with self._concurrent_queries() as cq:
-tuples = zip(node_ids, texts, text_embeddings, metadatas, nodes_links)
-for node_id, text, text_embedding, metadata, links in tuples:
-link_to_tags = set()  # link to these tags
-link_from_tags = set()  # link from these tags
+tuples = zip(node_ids, texts, text_embeddings, metadata_list, incoming_links_list)
+for node_id, text, text_embedding, metadata, incoming_links in tuples:

metadata_s = {
k: self._coerce_string(v)
for k, v in metadata.items()
if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
}

-for tag in links:
-if tag.direction in {"in", "bidir"}:
-metadata_s[_metadata_s_link_key(link=tag)] = _metadata_s_link_value()
+for incoming_link in incoming_links:
+metadata_s[_metadata_s_link_key(link=incoming_link)] = _metadata_s_link_value()

metadata_blob = _serialize_metadata(metadata)

cq.execute(
self._insert_passage,
parameters=(
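The new add_nodes path leans on three helpers that do not appear in this diff: _serialize_links, which folds the links into the metadata blob, and _metadata_s_link_key / _metadata_s_link_value, which flatten incoming links into the indexed metadata_s map. A hedged sketch of plausible shapes, for illustration only:

# Hypothetical sketches; the real helpers elsewhere in this module may differ.
import json

def _serialize_links(links: set[Link]) -> str:
    # Assumes Link exposes kind, direction, and tag attributes.
    return json.dumps(
        [{"kind": l.kind, "direction": l.direction, "tag": l.tag} for l in links]
    )

def _metadata_s_link_key(link: Link) -> str:
    # One indexed map key per incoming link, e.g. "link_from_kw:python".
    return f"link_from_{link.kind}:{link.tag}"

def _metadata_s_link_value() -> str:
    # Constant marker value; only the key matters for the lookup.
    return "link"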
@@ -406,60 +402,58 @@ def mmr_traversal_search(
score_threshold=score_threshold,
)

-# For each unselected node, stores the outgoing tags.
-outgoing_links: dict[str, set[Link]] = {}
+# For each unselected node, stores the outgoing links.
+outgoing_links_map: dict[str, set[Link]] = {}
visited_links: set[Link] = set()


def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
-nonlocal outgoing_links
+nonlocal outgoing_links_map
nonlocal visited_links

-# Put the neighborhood into the outgoing tags, to avoid adding it
+# Put the neighborhood into the outgoing links, to avoid adding it
# to the candidate set in the future.
-outgoing_links.update({content_id: set() for content_id in neighborhood})
+outgoing_links_map.update({content_id: set() for content_id in neighborhood})

-# Initialize the visited_links with the set of outgoing from the
+# Initialize the visited_links with the set of outgoing links from the
# neighborhood. This prevents re-visiting them.
visited_links = self._get_outgoing_links(neighborhood)

# Call `self._get_adjacent` to fetch the candidates.
adjacent_nodes = self._get_adjacent(
links=visited_links,
query_embedding=query_embedding,
-k_per_tag=adjacent_k,
+k_per_link=adjacent_k,
metadata_filter=metadata_filter,
)

-new_candidates = {}
+new_candidates: dict[str, list[float]] = {}
for adjacent_node in adjacent_nodes:
-if adjacent_node.id not in outgoing_links:
-outgoing_links[adjacent_node.id] = (
-adjacent_node.outgoing_links()
-)
-
-new_candidates[adjacent_node.id] = (
-adjacent_node.embedding
-)
+if adjacent_node.id not in outgoing_links_map:
+outgoing_links_map[adjacent_node.id] = adjacent_node.outgoing_links()
+new_candidates[adjacent_node.id] = adjacent_node.embedding
helper.add_candidates(new_candidates)

def fetch_initial_candidates() -> None:
+nonlocal outgoing_links_map
+nonlocal visited_links

initial_candidates_query, params = self._get_search_cql_and_params(
columns = "content_id, text_embedding, metadata_blob",
limit=fetch_k,
metadata=metadata_filter,
embedding=query_embedding,
)

-fetched = self._session.execute(
+rows = self._session.execute(
query=initial_candidates_query, parameters=params
)
-candidates = {}
-for row in fetched:
-if row.content_id not in outgoing_links:
+candidates: dict[str, list[float]] = {}
+for row in rows:
+if row.content_id not in outgoing_links_map:
node = _row_to_node(row=row)
+outgoing_links_map[node.id] = node.outgoing_links()
candidates[node.id] = node.embedding
-outgoing_links[node.id] = node.outgoing_links()
helper.add_candidates(candidates)

if initial_roots:
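A hypothetical call for orientation; parameter names follow the signature above where visible, while store, the query argument, and k are assumptions:

# Hypothetical usage sketch.
for node in store.mmr_traversal_search(
    "how are documents linked?",
    k=4,            # final, diversity-re-ranked results
    fetch_k=100,    # initial ANN candidates
    adjacent_k=10,  # candidates fetched per outgoing link
    depth=2,        # maximum traversal distance from the initial candidates
):
    print(node.id)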
@@ -482,39 +476,33 @@ def fetch_initial_candidates() -> None:
# If the next nodes would not exceed the depth limit, find the
# adjacent nodes.
#
-# TODO: For a big performance win, we should track which tags we've
+# TODO: For a big performance win, we should track which links we've
[Contributor review comment]: It seems like this may be a stale comment. Specifically, the difference_update on line 487 seems to be doing this.

# already incorporated. We don't need to issue adjacent queries for
# those.

-# Find the tags linked to from the selected ID.
-link_to_tags = outgoing_links.pop(selected_id)
+# Find the links linked to from the selected ID.
+selected_outgoing_links = outgoing_links_map.pop(selected_id)

-# Don't re-visit already visited tags.
-link_to_tags.difference_update(visited_links)
+# Don't re-visit already visited links.
+selected_outgoing_links.difference_update(visited_links)

-# Find the nodes with incoming links from those tags.
+# Find the nodes with incoming links from those links.
adjacent_nodes = self._get_adjacent(
-links=link_to_tags,
+links=selected_outgoing_links,
query_embedding=query_embedding,
-k_per_tag=adjacent_k,
+k_per_link=adjacent_k,
metadata_filter=metadata_filter,
)

-# Record the link_to_tags as visited.
-visited_links.update(link_to_tags)
+# Record the selected_outgoing_links as visited.
+visited_links.update(selected_outgoing_links)

new_candidates = {}
for adjacent_node in adjacent_nodes:
-if adjacent_node.id not in outgoing_links:
-outgoing_links[adjacent_node.id] = (
-adjacent_node.outgoing_links()
-)
-new_candidates[adjacent_node.id] = (
-adjacent_node.embedding
-)
-if next_depth < depths.get(
-adjacent_node.id, depth + 1
-):
+if adjacent_node.id not in outgoing_links_map:
+outgoing_links_map[adjacent_node.id] = adjacent_node.outgoing_links()
+new_candidates[adjacent_node.id] = adjacent_node.embedding
+if next_depth < depths.get(adjacent_node.id, depth + 1):
# If this is a new shortest depth, or there was no
# previous depth, update the depths. This ensures that
# when we discover a node we will have the shortest
@@ -556,12 +544,12 @@ def traversal_search(
"""
# Depth 0:
# Query for `k` nodes similar to the question.
-# Retrieve `content_id` and `link_to_tags`.
+# Retrieve `content_id` and `outgoing_links()`.
#
# Depth 1:
-# Query for nodes that have an incoming tag in the `link_to_tags` set.
+# Query for nodes that have an incoming link in the `outgoing_links()` set.
# Combine node IDs.
-# Query for `link_to_tags` of those "new" node IDs.
+# Query for `outgoing_links()` of those "new" node IDs.
#
# ...
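Restated as a sequential sketch (helper names are assumed; the real method below issues these queries concurrently via callbacks):

# Simplified, synchronous sketch of the traversal described above.
def sketch_traversal(query_embedding, k, depth):
    visited_ids = {}
    frontier = similarity_by_embedding(query_embedding, k=k)  # depth 0, assumed helper
    for d in range(depth + 1):
        for node in frontier:
            visited_ids.setdefault(node.id, d)
        if d == depth:
            break
        links = {link for node in frontier for link in node.outgoing_links()}
        frontier = [
            node
            for node in nodes_with_incoming_links(links)  # assumed helper
            if node.id not in visited_ids
        ]
    return visited_ids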

@@ -572,19 +560,18 @@
# Map from visited ID to depth
visited_ids: dict[str, int] = {}

-# Map from visited tag `(kind, tag)` to depth. Allows skipping queries
-# for tags that we've already traversed.
+# Map from visited link to depth. Allows skipping queries
+# for links that we've already traversed.
visited_links: dict[Link, int] = {}

def visit_nodes(d: int, rows: Sequence[Any]) -> None:
nonlocal visited_ids
nonlocal visited_links

# Visit nodes at the given depth.
-# Each node has `content_id` and `link_to_tags`.

-# Iterate over nodes, tracking the *new* outgoing kind tags for this
-# depth. This is tags that are either new, or newly discovered at a
+# Iterate over nodes, tracking the *new* outgoing links for this
+# depth. These are links that are either new, or newly discovered at a
# lower depth.
outgoing_links: Set[Link] = set()
for row in rows:
@@ -598,17 +585,17 @@ def visit_nodes(d: int, rows: Sequence[Any]) -> None:
if d < depth:
node = _row_to_node(row=row)
# Record any new (or newly discovered at a lower depth)
-# tags to the set to traverse.
+# links to the set to traverse.
for link in node.outgoing_links():
if d <= visited_links.get(link, depth):
-# Record that we'll query this tag at the
+# Record that we'll query this link at the
# given depth, so we don't fetch it again
# (unless we find it an earlier depth)
visited_links[link] = d
outgoing_links.add(link)

if outgoing_links:
-# If there are new tags to visit at the next depth, query for the
+# If there are new links to visit at the next depth, query for the
# node IDs.
for outgoing_link in outgoing_links:
visit_nodes_query, params = self._get_search_cql_and_params(
@@ -622,20 +609,19 @@ def visit_nodes(d: int, rows: Sequence[Any]) -> None:
callback=lambda rows, d=d: visit_targets(d, rows),
)

-def visit_targets(d: int, targets: Sequence[Any]) -> None:
+def visit_targets(d: int, rows: Sequence[Any]) -> None:
nonlocal visited_ids

-# target_content_id, tag=(kind,value)
-new_nodes_at_next_depth = set()
-for target in targets:
-content_id = target.target_content_id
+new_node_ids_at_next_depth = set()
+for row in rows:
+content_id = row.target_content_id
if d < visited_ids.get(content_id, depth):
-new_nodes_at_next_depth.add(content_id)
+new_node_ids_at_next_depth.add(content_id)

-if new_nodes_at_next_depth:
-for node_id in new_nodes_at_next_depth:
+if new_node_ids_at_next_depth:
+for node_id in new_node_ids_at_next_depth:
cq.execute(
-self._query_ids_and_link_to_tags_by_id,
+self._query_id_and_metadata_by_id,
parameters=(node_id,),
callback=lambda rows, d=d: visit_nodes(d + 1, rows),
)
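One detail worth noting in the callbacks above: d=d binds the current depth as a default argument. Default values are evaluated when the lambda is created, so each queued callback keeps the depth it was scheduled with even though it runs later. A minimal, standalone demonstration of the idiom:

# Default arguments are evaluated at definition time (early binding).
callbacks = [lambda x, i=i: x + i for i in range(3)]
print(callbacks[0](10), callbacks[2](10))  # 10 12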
Expand Down Expand Up @@ -663,7 +649,10 @@ def similarity_search(
) -> Iterable[Node]:
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501
query, params = self._get_search_cql_and_params(
-embedding=embedding, limit=k, metadata=metadata_filter
+columns=f"{CONTENT_ID}, text_content, metadata_blob",
+embedding=embedding,
+limit=k,
+metadata=metadata_filter,
)

for row in self._session.execute(query, params):
@@ -675,7 +664,11 @@ def metadata_search(
n: int = 5,
) -> Iterable[Node]:
"""Retrieve nodes based on their metadata."""
-query, params = self._get_search_cql_and_params(metadata=metadata, limit=n)
+query, params = self._get_search_cql_and_params(
+columns=f"{CONTENT_ID}, text_content, metadata_blob",
+metadata=metadata,
+limit=n,
+)

for row in self._session.execute(query, params):
yield _row_to_node(row)
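Hypothetical usage of the two entry points above; store, the embedding value, and the metadata key are assumptions:

# Hypothetical usage sketch.
for node in store.similarity_search(embedding=query_embedding, k=4):
    print(node.id)

for node in store.metadata_search(metadata={"section": "intro"}, n=5):
    print(node.id)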
@@ -688,10 +681,10 @@ def _get_outgoing_links(
self,
source_ids: Iterable[str],
) -> set[Link]:
"""Return the set of outgoing tags for the given source ID(s).
"""Return the set of outgoing links for the given source ID(s).

Args:
-source_ids: The IDs of the source nodes to retrieve outgoing tags for.
+source_ids: The IDs of the source nodes to retrieve outgoing links for.
"""
links = set()

@@ -703,7 +696,7 @@ def add_sources(rows: Iterable[Any]) -> None:
with self._concurrent_queries() as cq:
for source_id in source_ids:
cq.execute(
-self._query_ids_and_link_to_tags_by_id,
+self._query_id_and_metadata_by_id,
(source_id,),
callback=add_sources,
)
@@ -714,15 +707,15 @@ def _get_adjacent(
self,
links: set[Link],
query_embedding: list[float],
-k_per_tag: int | None = None,
+k_per_link: int | None = None,
metadata_filter: dict[str, Any] | None = None,
) -> Iterable[Node]:
"""Return the target nodes with incoming links from any of the given tags.
"""Return the target nodes with incoming links from any of the given links.

Args:
-tags: The tags to look for links *from*.
+links: The links to look for.
query_embedding: The query embedding. Used to rank target nodes.
-k_per_tag: The number of target nodes to fetch for each outgoing tag.
+k_per_link: The number of target nodes to fetch for each link.
metadata_filter: Optional metadata to filter the results.

Returns:
@@ -745,7 +738,7 @@ def add_targets(rows: Iterable[Any]) -> None:
for link in links:
adjacent_query, params = self._get_search_cql_and_params(
columns = "content_id, text_embedding, metadata_blob",
-limit=k_per_tag or 10,
+limit=k_per_link or 10,
metadata=metadata_filter,
embedding=query_embedding,
outgoing_link=link,