diff --git a/milvus/docker-compose.yml b/milvus/docker-compose.yml
new file mode 100644
index 0000000000..3589060510
--- /dev/null
+++ b/milvus/docker-compose.yml
@@ -0,0 +1,26 @@
+version: "3.9"
+
+services:
+ milvus:
+ image: milvusdb/milvus:v2.6.1
+ container_name: milvus
+ command: ["milvus", "run", "standalone"]
+ security_opt:
+ - seccomp:unconfined
+ environment:
+ ETCD_USE_EMBED: "true"
+ COMMON_STORAGETYPE: "local"
+ DEPLOY_MODE: "STANDALONE"
+ ports:
+ - "9091:9091"
+ - "19530:19530"
+
+ attu:
+ image: zilliz/attu:v2.6
+ container_name: attu
+ environment:
+ MILVUS_URL: "http://milvus:19530"
+ ports:
+ - "8000:3000"
+ depends_on:
+ - milvus
diff --git a/python/python/raphtory/graphql/__init__.pyi b/python/python/raphtory/graphql/__init__.pyi
index a95c9f6125..8f433b7225 100644
--- a/python/python/raphtory/graphql/__init__.pyi
+++ b/python/python/raphtory/graphql/__init__.pyi
@@ -93,26 +93,6 @@ class GraphServer(object):
None:
"""
- def set_embeddings(
- self,
- cache: str,
- embedding: Optional[Callable] = None,
- nodes: bool | str = True,
- edges: bool | str = True,
- ) -> GraphServer:
- """
- Setup the server to vectorise graphs with a default template.
-
- Arguments:
- cache (str): the directory to use as cache for the embeddings.
- embedding (Callable, optional): the embedding function to translate documents to embeddings.
- nodes (bool | str): if nodes have to be embedded or not or the custom template to use if a str is provided. Defaults to True.
- edges (bool | str): if edges have to be embedded or not or the custom template to use if a str is provided. Defaults to True.
-
- Returns:
- GraphServer: A new server object with embeddings setup.
- """
-
def start(self, port: int = 1736, timeout_ms: int = 5000) -> RunningGraphServer:
"""
Start the server and return a handle to it.
@@ -127,22 +107,21 @@ class GraphServer(object):
RunningGraphServer: The running server
"""
- def turn_off_index(self) -> GraphServer:
- """
- Turn off index for all graphs
+ def turn_off_index(self):
+ """Turn off index for all graphs"""
- Returns:
- GraphServer: The server with indexing disabled
- """
-
- def with_vectorised_graphs(
- self, graph_names: list[str], nodes: bool | str = True, edges: bool | str = True
+ def vectorise_graph(
+ self,
+ name: str,
+ embeddings,
+ nodes: bool | str = True,
+ edges: bool | str = True,
- ) -> GraphServer:
+ ) -> None:
"""
- Vectorise a subset of the graphs of the server.
+ Vectorise the named graph in the server's working directory.
Arguments:
- graph_names (list[str]): the names of the graphs to vectorise. All by default.
+ name (str): the name of the graph to vectorise.
+ embeddings (OpenAIEmbeddings): the embedding model used to generate the embeddings.
nodes (bool | str): if nodes have to be embedded or not or the custom template to use if a str is provided. Defaults to True.
edges (bool | str): if edges have to be embedded or not or the custom template to use if a str is provided. Defaults to True.
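Putting the new stub surface together: `set_embeddings`/`with_vectorised_graphs` are gone, and vectorisation is now an explicit per-graph call against a running embedding model. A minimal sketch of the new flow, mirroring `test_include_graph` further down (the port, work dir, graph name and template are illustrative):

```python
from raphtory.graphql import GraphServer
from raphtory.vectors import OpenAIEmbeddings, embedding_server

@embedding_server(address="0.0.0.0:7340")
def embeddings(text: str):
    # toy embedding used for illustration only
    return [text.count("a"), text.count("b")]

server = GraphServer("/tmp/work_dir")  # work dir already containing the graph
with embeddings.start():
    # vectorise one named graph, pointing the OpenAI-compatible client
    # at the local embedding server started above
    server.vectorise_graph(
        name="my_graph",
        embeddings=OpenAIEmbeddings(api_base="http://localhost:7340"),
        nodes="{{ name }}",
        edges=False,
    )
    with server.start():
        pass  # serve queries while the embedding server is still up
```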
diff --git a/python/python/raphtory/vectors/__init__.pyi b/python/python/raphtory/vectors/__init__.pyi
index c9011c1357..73b98cb4b5 100644
--- a/python/python/raphtory/vectors/__init__.pyi
+++ b/python/python/raphtory/vectors/__init__.pyi
@@ -25,7 +25,14 @@ import networkx as nx # type: ignore
import pyvis # type: ignore
from raphtory.iterables import *
-__all__ = ["VectorisedGraph", "Document", "Embedding", "VectorSelection"]
+__all__ = [
+ "VectorisedGraph",
+ "Document",
+ "Embedding",
+ "VectorSelection",
+ "OpenAIEmbeddings",
+ "embedding_server",
+]
class VectorisedGraph(object):
"""VectorisedGraph object that contains embedded documents that correspond to graph entities."""
@@ -37,10 +44,10 @@ class VectorisedGraph(object):
window: Optional[Tuple[int | str, int | str]] = None,
) -> VectorSelection:
"""
- Perform a similarity search between each edge's associated document and a specified `query`. Returns a number of edges up to a specified `limit` ranked in descending order of similarity score.
+ Perform a similarity search between each edge's associated document and a specified `query`. Returns a number of edges up to a specified `limit` ranked in ascending order of distance.
Args:
- query (str | list): The text or the embedding to score against.
+ query (str | list): The text or the embedding to calculate the distance from.
limit (int): The maximum number of new edges in the results.
window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
@@ -58,10 +65,10 @@ class VectorisedGraph(object):
window: Optional[Tuple[int | str, int | str]] = None,
) -> VectorSelection:
"""
- Perform a similarity search between each entity's associated document and a specified `query`. Returns a number of entities up to a specified `limit` ranked in descending order of similarity score.
+ Perform a similarity search between each entity's associated document and a specified `query`. Returns a number of entities up to a specified `limit` ranked in ascending order of distance.
Args:
- query (str | list): The text or the embedding to score against.
+ query (str | list): The text or the embedding to calculate the distance from.
limit (int): The maximum number of new entities in the result.
window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
@@ -76,17 +83,20 @@ class VectorisedGraph(object):
window: Optional[Tuple[int | str, int | str]] = None,
) -> VectorSelection:
"""
- Perform a similarity search between each node's associated document and a specified `query`. Returns a number of nodes up to a specified `limit` ranked in descending order of similarity score.
+ Perform a similarity search between each node's associated document and a specified `query`. Returns a number of nodes up to a specified `limit` ranked in ascending order of distance.
Args:
- query (str | list): The text or the embedding to score against.
+ query (str | list): The text or the embedding to calculate the distance from.
limit (int): The maximum number of new nodes in the result.
window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
Returns:
VectorSelection: The vector selection resulting from the search.
"""
+ def optimize_index(self):
+ """Optmize the vector index"""
+
class Document(object):
"""A document corresponding to a graph entity. Used to generate embeddings."""
@@ -129,7 +139,7 @@ class VectorSelection(object):
"""
Add all the documents associated with the specified `edges` to the current selection.
- Documents added by this call are assumed to have a score of 0.
+ Documents added by this call are assumed to have a distance of 0.
Args:
edges (list): List of the edge ids or edges to add.
@@ -142,7 +152,7 @@ class VectorSelection(object):
"""
Add all the documents associated with the specified `nodes` to the current selection.
- Documents added by this call are assumed to have a score of 0.
+ Documents added by this call are assumed to have a distance of 0.
Args:
nodes (list): List of the node ids or nodes to add.
@@ -196,12 +206,12 @@ class VectorSelection(object):
window: Optional[Tuple[int | str, int | str]] = None,
) -> None:
"""
- Add the top `limit` adjacent edges with higher score for `query` to the selection
+ Add the `limit` adjacent edges closest to `query` to the selection
This function has the same behaviour as expand_entities_by_similarity but it only considers edges.
Args:
- query (str | list): The text or the embedding to score against.
+ query (str | list): The text or the embedding to calculate the distance from.
limit (int): The maximum number of new edges to add.
window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
@@ -216,20 +226,19 @@ class VectorSelection(object):
window: Optional[Tuple[int | str, int | str]] = None,
) -> None:
"""
- Add the top `limit` adjacent entities with higher score for `query` to the selection
+ Add the `limit` adjacent entities closest to `query` to the selection
The expansion algorithm is a loop with two steps on each iteration:
1. All the entities 1 hop away of some of the entities included on the selection (and
not already selected) are marked as candidates.
- 2. Those candidates are added to the selection in descending order according to the
- similarity score obtained against the `query`.
+ 2. Those candidates are added to the selection in ascending order of distance from `query`.
- This loops goes on until the number of new entities reaches a total of `limit`
+ This loop goes on until the number of new entities reaches a total of `limit`
entities or until no more documents are available
Args:
- query (str | list): The text or the embedding to score against.
+ query (str | list): The text or the embedding to calculate the distance from.
limit (int): The number of documents to add.
window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
@@ -244,12 +253,12 @@ class VectorSelection(object):
window: Optional[Tuple[int | str, int | str]] = None,
) -> None:
"""
- Add the top `limit` adjacent nodes with higher score for `query` to the selection
+ Add the `limit` adjacent nodes closest to `query` to the selection
This function has the same behaviour as expand_entities_by_similarity but it only considers nodes.
Args:
- query (str | list): The text or the embedding to score against.
+ query (str | list): The text or the embedding to calculate the distance from.
limit (int): The maximum number of new nodes to add.
window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
@@ -265,12 +274,12 @@ class VectorSelection(object):
list[Document]: List of documents in the current selection.
"""
- def get_documents_with_scores(self) -> list[Tuple[Document, float]]:
+ def get_documents_with_distances(self) -> list[Tuple[Document, float]]:
"""
- Returns the documents present in the current selection alongside their scores.
+ Returns the documents present in the current selection alongside their distances.
Returns:
- list[Tuple[Document, float]]: List of documents and scores.
+ list[Tuple[Document, float]]: List of documents and distances.
"""
def nodes(self) -> list[Node]:
@@ -280,3 +289,16 @@ class VectorSelection(object):
Returns:
list[Node]: List of nodes in the current selection.
"""
+
+class OpenAIEmbeddings(object):
+ def __new__(
+ cls,
+ model="text-embedding-3-small",
+ api_base=None,
+ api_key_env=None,
+ org_id=None,
+ project_id=None,
+ ) -> OpenAIEmbeddings:
+ """Create and return a new object. See help(type) for accurate signature."""
+
+def embedding_server(address: str): ...
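The same primitives also work without a `GraphServer`: `Graph.vectorise` now takes an `OpenAIEmbeddings` handle instead of a callable, and results are ranked by ascending distance. A sketch under those assumptions (toy embedding function and port are illustrative):

```python
from raphtory import Graph
from raphtory.vectors import OpenAIEmbeddings, embedding_server

@embedding_server(address="0.0.0.0:7340")
def toy(text: str):
    # non-zero toy embedding so cosine distance is well defined
    return [text.count("a"), text.count("b"), 1.0]

g = Graph()
g.add_node(1, "node1")

with toy.start():
    embeddings = OpenAIEmbeddings(api_base="http://localhost:7340")
    vg = g.vectorise(embeddings, nodes="{{ name }}", edges=False)
    vg.optimize_index()  # optional: optimize the vector index once built
    selection = vg.nodes_by_similarity("node1", 1)
    for doc, distance in selection.get_documents_with_distances():
        print(doc.content, distance)  # an exact match has distance ~0.0
```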
diff --git a/python/tests/test_base_install/test_graphql/misc/test_graphql_vectors.py b/python/tests/test_base_install/test_graphql/misc/test_graphql_vectors.py
index 438c9ad31a..f78184e465 100644
--- a/python/tests/test_base_install/test_graphql/misc/test_graphql_vectors.py
+++ b/python/tests/test_base_install/test_graphql/misc/test_graphql_vectors.py
@@ -1,15 +1,12 @@
import tempfile
from raphtory.graphql import GraphServer, RaphtoryClient
from raphtory import Graph
+from raphtory.vectors import OpenAIEmbeddings, embedding_server
-def embedding(texts):
- return [[text.count("a"), text.count("b")] for text in texts]
-
-
-def test_embedding():
- result = embedding(texts=["aaa", "b", "ab", "ba"])
- assert result == [[3, 0], [0, 1], [1, 1], [1, 1]]
+@embedding_server(address="0.0.0.0:7340")
+def embeddings(text: str):
+ return [text.count("a"), text.count("b")]
def setup_graph(g):
@@ -58,51 +55,67 @@ def assert_correct_documents(client):
}
-def setup_server(work_dir):
- server = GraphServer(work_dir)
- server = server.set_embeddings(
- cache="/tmp/graph-cache",
- embedding=embedding,
- nodes="{{ name }}",
- edges=False,
- )
- return server
-
-
def test_new_graph():
print("test_new_graph")
work_dir = tempfile.TemporaryDirectory()
- server = setup_server(work_dir.name)
- with server.start():
- client = RaphtoryClient("http://localhost:1736")
- client.new_graph("abb", "EVENT")
- rg = client.remote_graph("abb")
- setup_graph(rg)
- assert_correct_documents(client)
+ server = GraphServer(work_dir.name)
+ with embeddings.start():
+ with server.start():
+ client = RaphtoryClient("http://localhost:1736")
+ client.new_graph("abb", "EVENT")
+ rg = client.remote_graph("abb")
+ setup_graph(rg)
+ client.query(
+ """
+ {
+ vectoriseGraph(path: "abb", model: { openAI: { model: "whatever", apiBase: "http://localhost:7340" } }, nodes: { custom: "{{ name }}" }, edges: { enabled: false })
+ }
+ """
+ )
+ assert_correct_documents(client)
def test_upload_graph():
print("test_upload_graph")
- work_dir = tempfile.mkdtemp()
- temp_dir = tempfile.mkdtemp()
- server = setup_server(work_dir)
- with server.start():
- client = RaphtoryClient("http://localhost:1736")
- g = Graph()
- setup_graph(g)
- g_path = temp_dir + "/abb"
- g.save_to_zip(g_path)
- client.upload_graph(path="abb", file_path=g_path, overwrite=True)
- assert_correct_documents(client)
+ work_dir = tempfile.TemporaryDirectory()
+ temp_dir = tempfile.TemporaryDirectory()
+ server = GraphServer(work_dir.name)
+ with embeddings.start():
+ with server.start():
+ client = RaphtoryClient("http://localhost:1736")
+ g = Graph()
+ setup_graph(g)
+ g_path = temp_dir.name + "/abb"
+ g.save_to_zip(g_path)
+ client.upload_graph(path="abb", file_path=g_path, overwrite=True)
+ client.query(
+ """
+ {
+ vectoriseGraph(path: "abb", model: { openAI: { model: "whatever", apiBase: "http://localhost:7340" } }, nodes: { custom: "{{ name }}" }, edges: { enabled: false })
+ }
+ """
+ )
+ assert_correct_documents(client)
+
+
+GRAPH_NAME = "abb"
def test_include_graph():
- work_dir = tempfile.mkdtemp()
- g_path = work_dir + "/abb"
+ work_dir = tempfile.TemporaryDirectory()
+ g_path = work_dir.name + "/" + GRAPH_NAME
g = Graph()
setup_graph(g)
g.save_to_file(g_path)
- server = setup_server(work_dir)
- with server.start():
- client = RaphtoryClient("http://localhost:1736")
- assert_correct_documents(client)
+ server = GraphServer(work_dir.name)
+ with embeddings.start():
+ embedding_client = OpenAIEmbeddings(api_base="http://localhost:7340")
+ server.vectorise_graph(
+ name=GRAPH_NAME,
+ embeddings=embedding_client,
+ nodes="{{ name }}",
+ edges=False,
+ )
+ with server.start():
+ client = RaphtoryClient("http://localhost:1736")
+ assert_correct_documents(client)
diff --git a/python/tests/test_base_install/test_vectors.py b/python/tests/test_base_install/test_vectors.py
index 9eb455eae4..91a75d071c 100644
--- a/python/tests/test_base_install/test_vectors.py
+++ b/python/tests/test_base_install/test_vectors.py
@@ -1,7 +1,12 @@
+import pytest
+import json
+from urllib.request import Request, urlopen
+from urllib.error import HTTPError
from raphtory import Graph
-from raphtory.vectors import VectorisedGraph
+from raphtory.vectors import VectorisedGraph, OpenAIEmbeddings, embedding_server
embedding_map = {
+ "raphtory": [1.0, 0.0, 0.0], # this is now needed,
"node1": [1.0, 0.0, 0.0],
"node2": [0.0, 1.0, 0.0],
"node3": [0.0, 0.0, 1.0],
@@ -12,19 +17,46 @@
}
-def single_embedding(text: str):
- try:
+@pytest.fixture(autouse=True)
+def test_server():
+ @embedding_server(address="0.0.0.0:7340") # TODO: ask only for PORT!!!
+ def custom_embeddings(text: str):
return embedding_map[text]
- except:
- raise Exception(f"unexpected document content: {text}")
+ with custom_embeddings.start():
+ yield
+
+
+def post_json(url, payload):
+ data = json.dumps(payload).encode()
+ req = Request(
+ url,
+ data=data,
+ headers={"Content-Type": "application/json"},
+ method="POST",
+ )
+ try:
+ with urlopen(req, timeout=10) as r:
+ return r.status, r.read()
+ except HTTPError as e:
+ return e.code, e.read()
+
+
+def test_failing_python_embeddings():
+ @embedding_server(address="0.0.0.0:7342")
+ def failing_embeddings(text: str):
+ assert False
-def embedding(texts):
- return [single_embedding(text) for text in texts]
+ with failing_embeddings.start():
+ payload = {"model": "whatever", "input": ["Hello world"]}
+ status, _ = post_json("http://localhost:7342/embeddings", payload)
+ assert status == 500
+ status, _ = post_json("http://localhost:7342/embeddings", payload)
+ assert status == 500
def floats_are_equals(float1: float, float2: float) -> bool:
- return float1 + 0.001 > float2 and float1 - 0.001 < float2
+ return abs(float1 - float2) < 0.001
# the graph generated by this function looks like this:
@@ -48,26 +80,34 @@ def create_graph() -> VectorisedGraph:
g.add_edge(3, "node1", "node3", {"name": "edge2"})
g.add_edge(4, "node3", "node4", {"name": "edge3"})
- vg = g.vectorise(embedding, nodes="{{ name }}", edges="{{ properties.name }}")
+ embeddings = OpenAIEmbeddings(api_base="http://localhost:7340")
+ vg = g.vectorise(embeddings, nodes="{{ name }}", edges="{{ properties.name }}")
return vg
+def test_embedding_server_context_manager():
+ @embedding_server(address="0.0.0.0:7341")
+ def constant(text: str):
+ return [1.0]
+
+ with constant.start():
+ payload = {
+ # "model": "whatever",
+ "input": ["The text to vectorise"]
+ }
+ status, body = post_json("http://localhost:7341/embeddings", payload)
+ assert status == 200
+ result = json.loads(body)
+ vector = result["data"][0]["embedding"]
+ assert vector == [1.0]
+
+
def test_selection():
vg = create_graph()
- ################################
- selection = vg.empty_selection()
- nodes_to_select = ["node1", "node2"]
- edges_to_select = [("node1", "node2"), ("node1", "node3")]
- selection = vg.empty_selection()
- selection.add_nodes(nodes_to_select)
- selection.add_edges(edges_to_select)
- nodes = selection.nodes()
- ###########################
-
assert len(vg.empty_selection().get_documents()) == 0
- assert len(vg.empty_selection().get_documents_with_scores()) == 0
+ assert len(vg.empty_selection().get_documents_with_distances()) == 0
nodes_to_select = ["node1", "node2"]
edges_to_select = [("node1", "node2"), ("node1", "node3")]
@@ -77,7 +117,9 @@ def test_selection():
nodes = selection.nodes()
node_names_returned = [node.name for node in nodes]
assert node_names_returned == nodes_to_select
+ print("before get documents")
docs = [doc.content for doc in selection.get_documents()]
+ print("after get documents")
assert docs == ["node1", "node2"]
selection = vg.empty_selection()
@@ -113,8 +155,10 @@ def test_search():
assert edge_names_returned == [("node1", "node2")]
# TODO: same for edges ?
- [(doc1, score1)] = vg.entities_by_similarity("node1", 1).get_documents_with_scores()
- assert floats_are_equals(score1, 1.0)
+ [(doc1, distance1)] = vg.entities_by_similarity(
+ "node1", 1
+ ).get_documents_with_distances()
+ assert floats_are_equals(distance1, 0.0)
assert (doc1.entity.name, doc1.content) == ("node1", "node1")
# chained search
@@ -205,8 +249,9 @@ def test_filtering_by_entity_type():
assert contents == ["edge1", "edge2", "edge3"]
-def constant_embedding(texts):
- return [[1.0, 0.0, 0.0] for text in texts]
+@embedding_server(address="0.0.0.0:7341")
+def constant_embedding(_text):
+ return [1.0, 0.0, 0.0]
def test_default_template():
@@ -214,7 +259,9 @@ def test_default_template():
g.add_node(1, "node1")
g.add_edge(2, "node1", "node1")
- vg = g.vectorise(constant_embedding)
+ running = constant_embedding.start()
+
+ vg = g.vectorise(OpenAIEmbeddings(api_base="http://localhost:7341"))
node_docs = vg.nodes_by_similarity(query="whatever", limit=10).get_documents()
assert len(node_docs) == 1
@@ -226,3 +273,5 @@ def test_default_template():
edge_docs[0].content
== "There is an edge from node1 to node1 with events at:\n- Jan 1 1970 00:00\n"
)
+
+ running.stop()
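One note on the last test: `start()` returns a handle that the other tests use as a context manager; when the handle is kept manually, a `try`/`finally` keeps the port from leaking if an assertion fails. A sketch of the equivalent cleanup-safe shape:

```python
# Sketch: cleanup-safe variant of test_default_template above.
running = constant_embedding.start()  # handle from the decorated function above
try:
    g = Graph()
    g.add_node(1, "node1")
    vg = g.vectorise(OpenAIEmbeddings(api_base="http://localhost:7341"))
    docs = vg.nodes_by_similarity(query="whatever", limit=10).get_documents()
    assert len(docs) == 1
finally:
    running.stop()  # always free the port, even if an assertion fails
```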
diff --git a/raphtory-benchmark/src/common/vectors.rs b/raphtory-benchmark/src/common/vectors.rs
index 701ace6db2..e5d5299461 100644
--- a/raphtory-benchmark/src/common/vectors.rs
+++ b/raphtory-benchmark/src/common/vectors.rs
@@ -4,8 +4,9 @@ use rand::{rngs::StdRng, Rng, SeedableRng};
use raphtory::{
prelude::{AdditionOps, Graph, NO_PROPS},
vectors::{
- cache::VectorCache, embeddings::EmbeddingResult, template::DocumentTemplate,
- vectorisable::Vectorisable, vectorised_graph::VectorisedGraph, Embedding,
+ cache::VectorCache, embeddings::EmbeddingResult, storage::OpenAIEmbeddings,
+ template::DocumentTemplate, vectorisable::Vectorisable, vectorised_graph::VectorisedGraph,
+ Embedding,
},
};
use tokio::runtime::Runtime;
@@ -35,12 +36,22 @@ pub fn create_graph_for_vector_bench(size: usize) -> Graph {
}
pub async fn vectorise_graph_for_bench_async(graph: Graph) -> VectorisedGraph {
- let cache = VectorCache::in_memory(embedding_model);
+ let cache = VectorCache::in_memory();
+ let model = cache
+ .openai(OpenAIEmbeddings {
+ model: "whatever".to_owned(),
+ api_base: Some("localhost://1783".to_owned()), // TODO: run embedding server as well on the background so that this works
+ api_key_env: None,
+ project_id: None,
+ org_id: None,
+ })
+ .await
+ .unwrap();
let template = DocumentTemplate {
node_template: Some("{{name}}".to_owned()),
edge_template: None,
};
- graph.vectorise(cache, template, None, true).await.unwrap()
+ graph.vectorise(model, template, None, true).await.unwrap()
}
// TODO: remove this version
diff --git a/raphtory-graphql/schema.graphql b/raphtory-graphql/schema.graphql
index 893c3ea2d1..8ed275b41b 100644
--- a/raphtory-graphql/schema.graphql
+++ b/raphtory-graphql/schema.graphql
@@ -711,6 +711,13 @@ type EdgesWindowSet {
list: [Edges!]!
}
+input EmbeddingModel @oneOf {
+ """
+ OpenAI embedding models or compatible providers
+ """
+ openAI: OpenAIConfig
+}
+
"""
Raphtory’s EventTime.
Represents a unique timepoint in the graph’s history as (timestamp, event_id).
@@ -2207,6 +2214,14 @@ input ObjectEntry {
value: Value!
}
+input OpenAIConfig {
+ model: String!
+ apiBase: String
+ apiKeyEnv: String
+ orgId: String
+ projectId: String
+}
+
enum Operator {
"""
Equality operator.
@@ -2546,6 +2561,12 @@ type QueryRoot {
"""
updateGraph(path: String!): MutableGraph!
"""
+ Vectorise the graph at the given path, has side effects to build or update its vector index
+
+ Returns:: Boolean
+ """
+ vectoriseGraph(path: String!, model: EmbeddingModel, nodes: Template, edges: Template): Boolean!
+ """
Create vectorised graph in the format used for queries
Returns:: GqlVectorisedGraph
@@ -2612,6 +2633,17 @@ enum SortByTime {
EARLIEST
}
+input Template @oneOf {
+ """
+ The default template.
+ """
+ enabled: Boolean
+ """
+ A custom template.
+ """
+ custom: String
+}
+
type TemporalProperties {
"""
Get property value matching the specified key.
@@ -2775,6 +2807,10 @@ type VectorSelection {
}
type VectorisedGraph {
+ """
+ Optimize the vector index
+ """
+ optimizeIndex: Boolean!
"""
Returns an empty selection of documents.
"""
diff --git a/raphtory-graphql/src/data.rs b/raphtory-graphql/src/data.rs
index dbcaa30c6a..fa6b80b561 100644
--- a/raphtory-graphql/src/data.rs
+++ b/raphtory-graphql/src/data.rs
@@ -8,15 +8,14 @@ use itertools::Itertools;
use moka::future::Cache;
use raphtory::{
db::api::view::MaterializedGraph,
- errors::{GraphError, InvalidPathReason},
+ errors::{GraphError, GraphResult, InvalidPathReason},
prelude::CacheOps,
vectors::{
- cache::VectorCache, template::DocumentTemplate, vectorisable::Vectorisable,
- vectorised_graph::VectorisedGraph,
+ cache::CachedEmbeddingModel, storage::LazyDiskVectorCache, template::DocumentTemplate,
+ vectorisable::Vectorisable, vectorised_graph::VectorisedGraph,
},
};
use std::{
- collections::HashMap,
path::{Path, PathBuf},
sync::Arc,
};
@@ -24,13 +23,6 @@ use tokio::fs;
use tracing::{error, warn};
use walkdir::WalkDir;
-#[derive(Clone)]
-pub struct EmbeddingConf {
- pub(crate) cache: VectorCache,
- pub(crate) global_template: Option,
- pub(crate) individual_templates: HashMap,
-}
-
pub(crate) fn get_relative_path(
work_dir: PathBuf,
path: &Path,
@@ -46,7 +38,6 @@ pub(crate) fn get_relative_path(
.ok_or(InvalidPathReason::NonUTFCharacters)
})
.collect::, _>>()?;
- //a safe unwrap as checking above
let path_str = components.into_iter().join("/");
valid_path(work_dir, &path_str, namespace)?;
Ok(path_str)
@@ -54,10 +45,10 @@ pub(crate) fn get_relative_path(
#[derive(Clone)]
pub struct Data {
- pub(crate) work_dir: PathBuf,
+ pub(crate) work_dir: PathBuf, // TODO: move this to config?
cache: Cache,
- pub(crate) create_index: bool,
- pub(crate) embedding_conf: Option,
+ pub(crate) create_index: bool, // TODO: move this to config?
+ pub(crate) vector_cache: LazyDiskVectorCache,
}
impl Data {
@@ -80,11 +71,13 @@ impl Data {
#[cfg(not(feature = "search"))]
let create_index = false;
+ // TODO: make vector feature optional?
+
Self {
work_dir: work_dir.to_path_buf(),
cache,
create_index,
- embedding_conf: Default::default(),
+ vector_cache: LazyDiskVectorCache::new(work_dir.join(".vector-cache")),
}
}
@@ -116,8 +109,7 @@ impl Data {
let folder_clone = folder.clone();
let graph_clone = graph.clone();
blocking_io(move || graph_clone.cache(folder_clone)).await?;
- let vectors = self.vectorise(graph.clone(), &folder).await;
- let graph = GraphWithVectors::new(graph, vectors);
+ let graph = GraphWithVectors::new(graph, None);
graph
.folder
.get_or_try_init(|| Ok::<_, GraphError>(folder.into()))?;
@@ -134,23 +126,16 @@ impl Data {
Ok(())
}
- fn resolve_template(&self, graph: &Path) -> Option<&DocumentTemplate> {
- let conf = self.embedding_conf.as_ref()?;
- conf.individual_templates
- .get(graph)
- .or(conf.global_template.as_ref())
- }
-
async fn vectorise_with_template(
&self,
graph: MaterializedGraph,
folder: &ValidGraphFolder,
template: &DocumentTemplate,
+ model: CachedEmbeddingModel,
) -> Option> {
- let conf = self.embedding_conf.as_ref()?;
let vectors = graph
.vectorise(
- conf.cache.clone(),
+ model,
template.clone(),
Some(&folder.get_vectors_path()),
true, // verbose
@@ -166,34 +151,18 @@ impl Data {
}
}
- async fn vectorise(
+ pub(crate) async fn vectorise_folder(
&self,
- graph: MaterializedGraph,
- folder: &ValidGraphFolder,
- ) -> Option> {
- let template = self.resolve_template(folder.get_original_path())?;
- self.vectorise_with_template(graph, folder, template).await
- }
-
- async fn vectorise_folder(&self, folder: &ExistingGraphFolder) -> Option<()> {
- // it's important that we check if there is a valid template set for this graph path
- // before actually loading the graph, otherwise we are loading the graph for no reason
- let template = self.resolve_template(folder.get_original_path())?;
- let graph = self
- .read_graph_from_folder(folder.clone())
- .await
- .ok()?
- .graph;
- self.vectorise_with_template(graph, folder, template).await;
- Some(())
- }
-
- pub(crate) async fn vectorise_all_graphs_that_are_not(&self) -> Result<(), GraphError> {
- for folder in self.get_all_graph_folders() {
- if !folder.get_vectors_path().exists() {
- self.vectorise_folder(&folder).await;
- }
- }
+ folder: &ExistingGraphFolder,
+ template: &DocumentTemplate,
+ model: CachedEmbeddingModel,
+ ) -> GraphResult<()> {
+ let graph = self.read_graph_from_folder(folder.clone()).await?.graph;
+ self.vectorise_with_template(graph, folder, template, model)
+ .await;
+ self.cache
+ .remove(&folder.get_original_path().to_path_buf())
+ .await;
Ok(())
}
@@ -216,9 +185,8 @@ impl Data {
&self,
folder: ExistingGraphFolder,
) -> Result {
- let cache = self.embedding_conf.as_ref().map(|conf| conf.cache.clone());
- let create_index = self.create_index;
- blocking_io(move || GraphWithVectors::read_from_folder(&folder, cache, create_index)).await
+ GraphWithVectors::read_from_folder(&folder, &self.vector_cache, self.create_index).await
+ // FIXME: I need some blocking_io inside of GraphWithVectors::read_from_folder
}
}
@@ -276,7 +244,7 @@ pub(crate) mod data_tests {
File::create(path.join("graph")).unwrap();
}
- pub(crate) fn save_graphs_to_work_dir(
+ pub(crate) async fn save_graphs_to_work_dir(
work_dir: &Path,
graphs: &HashMap,
) -> Result<(), GraphError> {
@@ -368,7 +336,9 @@ pub(crate) mod data_tests {
#[cfg(feature = "storage")]
graphs.insert("test_dg".to_string(), graph2);
- save_graphs_to_work_dir(tmp_work_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_work_dir.path(), &graphs)
+ .await
+ .unwrap();
let data = Data::new(tmp_work_dir.path(), &Default::default());
diff --git a/raphtory-graphql/src/embeddings.rs b/raphtory-graphql/src/embeddings.rs
deleted file mode 100644
index d65b59ee6e..0000000000
--- a/raphtory-graphql/src/embeddings.rs
+++ /dev/null
@@ -1,16 +0,0 @@
-use crate::data::Data;
-use async_graphql::Context;
-use raphtory::{errors::GraphResult, vectors::Embedding};
-
-pub(crate) trait EmbedQuery {
- async fn embed_query(&self, text: String) -> GraphResult;
-}
-
-impl EmbedQuery for Context<'_> {
- /// this is meant to be called from a vector context, so the embedding conf is assumed to exist
- async fn embed_query(&self, text: String) -> GraphResult {
- let data = self.data_unchecked::();
- let cache = &data.embedding_conf.as_ref().unwrap().cache;
- cache.get_single(text).await
- }
-}
diff --git a/raphtory-graphql/src/graph.rs b/raphtory-graphql/src/graph.rs
index 50a3468e60..fb47251b53 100644
--- a/raphtory-graphql/src/graph.rs
+++ b/raphtory-graphql/src/graph.rs
@@ -15,7 +15,9 @@ use raphtory::{
prelude::{CacheOps, EdgeViewOps, IndexMutationOps},
serialise::GraphFolder,
storage::core_ops::CoreGraphOps,
- vectors::{cache::VectorCache, vectorised_graph::VectorisedGraph},
+ vectors::{
+ cache::VectorCache, storage::LazyDiskVectorCache, vectorised_graph::VectorisedGraph,
+ },
};
use raphtory_storage::{
core_ops::InheritCoreGraphOps, graph::graph::GraphStorage, layer_ops::InheritLayerOps,
@@ -76,9 +78,9 @@ impl GraphWithVectors {
}
}
- pub(crate) fn read_from_folder(
+ pub(crate) async fn read_from_folder(
folder: &ExistingGraphFolder,
- cache: Option,
+ cache: &LazyDiskVectorCache,
create_index: bool,
) -> Result {
let graph_path = &folder.get_graph_path();
@@ -87,9 +89,11 @@ impl GraphWithVectors {
} else {
MaterializedGraph::load_cached(folder.clone())?
};
- let vectors = cache.and_then(|cache| {
- VectorisedGraph::read_from_path(&folder.get_vectors_path(), graph.clone(), cache).ok()
- });
+ let vectors =
+ VectorisedGraph::read_from_path(&folder.get_vectors_path(), graph.clone(), cache)
+ .await
+ .ok();
+
println!("Graph loaded = {}", folder.get_original_path_str());
if create_index {
graph.create_index()?;
diff --git a/raphtory-graphql/src/lib.rs b/raphtory-graphql/src/lib.rs
index 0706e8d966..da18c17492 100644
--- a/raphtory-graphql/src/lib.rs
+++ b/raphtory-graphql/src/lib.rs
@@ -1,7 +1,6 @@
pub use crate::server::GraphServer;
mod auth;
pub mod data;
-mod embeddings;
mod graph;
pub mod model;
pub mod observability;
@@ -99,7 +98,9 @@ mod graphql_test {
let graphs = HashMap::from([("master".to_string(), graph)]);
let tmp_dir = tempdir().unwrap();
- save_graphs_to_work_dir(tmp_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_dir.path(), &graphs)
+ .await
+ .unwrap();
let config = AppConfigBuilder::new().with_create_index(true).build();
let data = Data::new(tmp_dir.path(), &config);
@@ -199,7 +200,9 @@ mod graphql_test {
let graph: MaterializedGraph = graph.into();
let graphs = HashMap::from([("lotr".to_string(), graph)]);
let tmp_dir = tempdir().unwrap();
- save_graphs_to_work_dir(tmp_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_dir.path(), &graphs)
+ .await
+ .unwrap();
let data = Data::new(tmp_dir.path(), &AppConfig::default());
@@ -310,7 +313,9 @@ mod graphql_test {
let graphs = HashMap::from([("graph".to_string(), graph)]);
let tmp_dir = tempdir().unwrap();
- save_graphs_to_work_dir(tmp_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_dir.path(), &graphs)
+ .await
+ .unwrap();
let data = Data::new(tmp_dir.path(), &AppConfig::default());
let schema = App::create_schema().data(data).finish().unwrap();
@@ -413,7 +418,9 @@ mod graphql_test {
let graphs = HashMap::from([("graph".to_string(), graph)]);
let tmp_dir = tempdir().unwrap();
- save_graphs_to_work_dir(tmp_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_dir.path(), &graphs)
+ .await
+ .unwrap();
let data = Data::new(tmp_dir.path(), &AppConfig::default());
let schema = App::create_schema().data(data).finish().unwrap();
@@ -478,7 +485,9 @@ mod graphql_test {
let graph: MaterializedGraph = g.into();
let graphs = HashMap::from([("graph".to_string(), graph)]);
let tmp_dir = tempdir().unwrap();
- save_graphs_to_work_dir(tmp_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_dir.path(), &graphs)
+ .await
+ .unwrap();
let expected = json!({
"graph": {
@@ -629,7 +638,9 @@ mod graphql_test {
let g = g.into();
let graphs = HashMap::from([("graph".to_string(), g)]);
let tmp_dir = tempdir().unwrap();
- save_graphs_to_work_dir(tmp_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_dir.path(), &graphs)
+ .await
+ .unwrap();
let data = Data::new(tmp_dir.path(), &AppConfig::default());
let schema = App::create_schema().data(data).finish().unwrap();
@@ -956,7 +967,9 @@ mod graphql_test {
let graph = graph.into();
let graphs = HashMap::from([("graph".to_string(), graph)]);
let tmp_dir = tempdir().unwrap();
- save_graphs_to_work_dir(tmp_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_dir.path(), &graphs)
+ .await
+ .unwrap();
let data = Data::new(tmp_dir.path(), &AppConfig::default());
let schema = App::create_schema().data(data).finish().unwrap();
@@ -1131,7 +1144,9 @@ mod graphql_test {
let graph = graph.into();
let graphs = HashMap::from([("graph".to_string(), graph)]);
let tmp_dir = tempdir().unwrap();
- save_graphs_to_work_dir(tmp_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_dir.path(), &graphs)
+ .await
+ .unwrap();
let data = Data::new(tmp_dir.path(), &AppConfig::default());
let schema = App::create_schema().data(data).finish().unwrap();
@@ -1272,7 +1287,9 @@ mod graphql_test {
("graph6".to_string(), graph6.into()),
]);
let tmp_dir = tempdir().unwrap();
- save_graphs_to_work_dir(tmp_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_dir.path(), &graphs)
+ .await
+ .unwrap();
let data = Data::new(tmp_dir.path(), &AppConfig::default());
let schema = App::create_schema().data(data).finish().unwrap();
@@ -1569,7 +1586,9 @@ mod graphql_test {
let graph = graph.into();
let graphs = HashMap::from([("graph".to_string(), graph)]);
let tmp_dir = tempdir().unwrap();
- save_graphs_to_work_dir(tmp_dir.path(), &graphs).unwrap();
+ save_graphs_to_work_dir(tmp_dir.path(), &graphs)
+ .await
+ .unwrap();
let data = Data::new(tmp_dir.path(), &AppConfig::default());
let schema = App::create_schema().data(data).finish().unwrap();
diff --git a/raphtory-graphql/src/main.rs b/raphtory-graphql/src/main.rs
index b42ab3e5e6..f75272e635 100644
--- a/raphtory-graphql/src/main.rs
+++ b/raphtory-graphql/src/main.rs
@@ -105,7 +105,8 @@ async fn main() -> IoResult<()> {
let app_config = Some(builder.build());
- GraphServer::new(args.working_dir, app_config, None)?
+ GraphServer::new(args.working_dir, app_config, None)
+ .await?
.run_with_port(args.port)
.await?;
}
diff --git a/raphtory-graphql/src/model/graph/mutable_graph.rs b/raphtory-graphql/src/model/graph/mutable_graph.rs
index dbccd49d9e..b826cdd2c3 100644
--- a/raphtory-graphql/src/model/graph/mutable_graph.rs
+++ b/raphtory-graphql/src/model/graph/mutable_graph.rs
@@ -598,32 +598,22 @@ impl GqlMutableEdge {
#[cfg(test)]
mod tests {
use super::*;
- use crate::{
- config::app_config::AppConfig,
- data::{Data, EmbeddingConf},
- };
+ use crate::{config::app_config::AppConfig, data::Data};
use itertools::Itertools;
use raphtory::{
db::api::view::MaterializedGraph,
vectors::{
- cache::VectorCache, embeddings::EmbeddingResult, template::DocumentTemplate, Embedding,
+ custom::{serve_custom_embedding, EmbeddingServer},
+ embeddings::EmbeddingResult,
+ storage::OpenAIEmbeddings,
+ template::DocumentTemplate,
+ Embedding,
},
};
- use std::collections::HashMap;
use tempfile::tempdir;
- async fn fake_embedding(texts: Vec) -> EmbeddingResult> {
- Ok(texts
- .into_iter()
- .map(|_| vec![1.0, 0.0, 0.0].into())
- .collect_vec())
- }
-
- fn custom_template() -> DocumentTemplate {
- DocumentTemplate {
- node_template: Some("{{ name }} is a {{ node_type }}".to_string()),
- edge_template: Some("{{ src.name }} appeared with {{ dst.name}}".to_string()),
- }
+ fn fake_embedding(_: &str) -> Vec {
+ vec![1.0]
}
fn create_test_graph() -> MaterializedGraph {
@@ -631,42 +621,66 @@ mod tests {
graph.into()
}
- async fn create_mutable_graph() -> (GqlMutableGraph, tempfile::TempDir) {
+ async fn create_mutable_graph(
+ port: u16,
+ ) -> (GqlMutableGraph, tempfile::TempDir, EmbeddingServer) {
let graph = create_test_graph();
let tmp_dir = tempdir().unwrap();
let config = AppConfig::default();
- let mut data = Data::new(tmp_dir.path(), &config);
+ let data = Data::new(tmp_dir.path(), &config);
+
+ let graph_name = "test_graph";
- // Override the embedding function with a mock for testing.
- data.embedding_conf = Some(EmbeddingConf {
- cache: VectorCache::in_memory(fake_embedding),
- global_template: Some(custom_template()),
- individual_templates: HashMap::new(),
- });
+ data.insert_graph(graph_name, graph).await.unwrap();
- data.insert_graph("test_graph", graph).await.unwrap();
+ let template = DocumentTemplate {
+ node_template: Some("{{ name }} is a {{ node_type }}".to_string()),
+ edge_template: Some("{{ src.name }} appeared with {{ dst.name}}".to_string()),
+ };
+
+ let address = format!("0.0.0.0:{port}");
+ let embedding_server = serve_custom_embedding(&address, fake_embedding).await;
+
+ let api_base = format!("http://localhost:{port}");
+ let config = OpenAIEmbeddings {
+ model: "whatever".to_owned(),
+ api_base: Some(api_base),
+ api_key_env: None,
+ project_id: None,
+ org_id: None,
+ };
+ let vector_cache = data.vector_cache.resolve().await.unwrap();
+ let model = vector_cache.openai(config).await.unwrap();
+ data.vectorise_folder(
+ &ExistingGraphFolder::try_from(tmp_dir.path().to_path_buf(), graph_name).unwrap(),
+ &template,
+ model,
+ )
+ .await
+ .unwrap();
- let (graph_with_vectors, path) = data.get_graph("test_graph").await.unwrap();
+ let (graph_with_vectors, path) = data.get_graph(graph_name).await.unwrap();
let mutable_graph = GqlMutableGraph::new(path, graph_with_vectors);
- (mutable_graph, tmp_dir)
+ (mutable_graph, tmp_dir, embedding_server)
}
#[tokio::test]
async fn test_add_nodes_empty_list() {
- let (mutable_graph, _tmp_dir) = create_mutable_graph().await;
+ let (mutable_graph, _tmp_dir, _embedding_server) = create_mutable_graph(1745).await;
let nodes = vec![];
let result = mutable_graph.add_nodes(nodes).await;
assert!(result.is_ok());
assert!(result.unwrap());
+ _embedding_server.stop();
}
#[tokio::test]
async fn test_add_nodes_simple() {
- let (mutable_graph, _tmp_dir) = create_mutable_graph().await;
+ let (mutable_graph, _tmp_dir, _embedding_server) = create_mutable_graph(1746).await;
let nodes = vec![
NodeAddition {
@@ -694,22 +708,22 @@ mod tests {
assert!(result.is_ok());
assert!(result.unwrap());
- let query = "node1".to_string();
- let embedding = &fake_embedding(vec![query]).await.unwrap().remove(0);
+ let embedding = fake_embedding("node1");
let limit = 5;
let result = mutable_graph
.graph
.vectors
.unwrap()
- .nodes_by_similarity(embedding, limit, None);
+ .nodes_by_similarity(&embedding.into(), limit, None)
+ .await;
assert!(result.is_ok());
- assert!(result.unwrap().get_documents().unwrap().len() == 2);
+ assert!(result.unwrap().get_documents().await.unwrap().len() == 2);
}
#[tokio::test]
async fn test_add_nodes_with_properties() {
- let (mutable_graph, _tmp_dir) = create_mutable_graph().await;
+ let (mutable_graph, _tmp_dir, _embedding_server) = create_mutable_graph(1747).await;
let nodes = vec![
NodeAddition {
@@ -764,22 +778,22 @@ mod tests {
assert!(result.is_ok());
assert!(result.unwrap());
- let query = "complex_node_1".to_string();
- let embedding = &fake_embedding(vec![query]).await.unwrap().remove(0);
+ let embedding = fake_embedding("complex_node_1");
let limit = 5;
let result = mutable_graph
.graph
.vectors
.unwrap()
- .nodes_by_similarity(embedding, limit, None);
+ .nodes_by_similarity(&embedding.into(), limit, None)
+ .await;
assert!(result.is_ok());
- assert!(result.unwrap().get_documents().unwrap().len() == 3);
+ assert!(result.unwrap().get_documents().await.unwrap().len() == 3);
}
#[tokio::test]
async fn test_add_edges_simple() {
- let (mutable_graph, _tmp_dir) = create_mutable_graph().await;
+ let (mutable_graph, _tmp_dir, _embedding_server) = create_mutable_graph(1748).await;
// First add some nodes.
let nodes = vec![
@@ -839,16 +853,16 @@ mod tests {
assert!(result.unwrap());
// Test that edge embeddings were generated.
- let query = "node1 appeared with node2".to_string();
- let embedding = &fake_embedding(vec![query]).await.unwrap().remove(0);
+ let embedding = fake_embedding("node1 appeared with node2");
let limit = 5;
let result = mutable_graph
.graph
.vectors
.unwrap()
- .edges_by_similarity(embedding, limit, None);
+ .edges_by_similarity(&embedding.into(), limit, None)
+ .await;
assert!(result.is_ok());
- assert!(result.unwrap().get_documents().unwrap().len() == 2);
+ assert!(result.unwrap().get_documents().await.unwrap().len() == 2);
}
}
diff --git a/raphtory-graphql/src/model/graph/vector_selection.rs b/raphtory-graphql/src/model/graph/vector_selection.rs
index 74dce55449..5a3cba81b4 100644
--- a/raphtory-graphql/src/model/graph/vector_selection.rs
+++ b/raphtory-graphql/src/model/graph/vector_selection.rs
@@ -4,12 +4,12 @@ use super::{
node::GqlNode,
vectorised_graph::{IntoWindowTuple, VectorisedGraphWindow},
};
-use crate::{embeddings::EmbedQuery, rayon::blocking_compute};
-use async_graphql::Context;
+use crate::rayon::blocking_compute;
use dynamic_graphql::{InputObject, ResolvedObject, ResolvedObjectFields};
use raphtory::{
- db::api::view::MaterializedGraph, errors::GraphResult,
- vectors::vector_selection::VectorSelection,
+ db::api::view::MaterializedGraph,
+ errors::GraphResult,
+ vectors::{vector_selection::VectorSelection, Embedding},
};
#[derive(InputObject)]
@@ -45,18 +45,15 @@ impl GqlVectorSelection {
/// Returns a list of documents in the current selection.
async fn get_documents(&self) -> GraphResult> {
let cloned = self.0.clone();
- blocking_compute(move || {
- let docs = cloned.get_documents_with_scores()?.into_iter();
- Ok(docs
- .map(|(doc, score)| GqlDocument {
- content: doc.content,
- entity: doc.entity.into(),
- embedding: doc.embedding.to_vec(),
- score,
- })
- .collect())
- })
- .await
+ let docs = cloned.get_documents_with_distances().await?.into_iter();
+ Ok(docs
+ .map(|(doc, score)| GqlDocument {
+ content: doc.content,
+ entity: doc.entity.into(),
+ embedding: doc.embedding.to_vec(),
+ score,
+ })
+ .collect())
}
/// Adds all the documents associated with the specified nodes to the current selection.
@@ -64,11 +61,8 @@ impl GqlVectorSelection {
/// Documents added by this call are assumed to have a score of 0.
async fn add_nodes(&self, nodes: Vec) -> Self {
let mut selection = self.cloned();
- blocking_compute(move || {
- selection.add_nodes(nodes);
- selection.into()
- })
- .await
+ selection.add_nodes(nodes);
+ selection.into()
}
/// Adds all the documents associated with the specified edges to the current selection.
@@ -76,12 +70,9 @@ impl GqlVectorSelection {
/// Documents added by this call are assumed to have a score of 0.
async fn add_edges(&self, edges: Vec) -> Self {
let mut selection = self.cloned();
- blocking_compute(move || {
- let edges = edges.into_iter().map(|edge| (edge.src, edge.dst)).collect();
- selection.add_edges(edges);
- selection.into()
- })
- .await
+ let edges = edges.into_iter().map(|edge| (edge.src, edge.dst)).collect();
+ selection.add_edges(edges);
+ selection.into()
}
/// Add all the documents a specified number of hops away to the selection.
@@ -100,55 +91,49 @@ impl GqlVectorSelection {
/// Adds documents, from the set of one hop neighbours to the current selection, to the selection based on their similarity score with the specified query. This function loops so that the set of one hop neighbours expands on each loop and number of documents added is determined by the specified limit.
async fn expand_entities_by_similarity(
&self,
- ctx: &Context<'_>,
query: String,
limit: usize,
window: Option,
) -> GraphResult {
- let vector = ctx.embed_query(query).await?;
+ let vector = self.embed_text(query).await?;
let window = window.into_window_tuple();
let mut selection = self.cloned();
- blocking_compute(move || {
- selection.expand_entities_by_similarity(&vector, limit, window)?;
- Ok(selection.into())
- })
- .await
+ selection
+ .expand_entities_by_similarity(&vector, limit, window)
+ .await?;
+ Ok(selection.into())
}
/// Add the adjacent nodes with higher score for query to the selection up to a specified limit. This function loops like expand_entities_by_similarity but is restricted to nodes.
async fn expand_nodes_by_similarity(
&self,
- ctx: &Context<'_>,
query: String,
limit: usize,
window: Option,
) -> GraphResult {
- let vector = ctx.embed_query(query).await?;
+ let vector = self.embed_text(query).await?;
let window = window.into_window_tuple();
let mut selection = self.cloned();
- blocking_compute(move || {
- selection.expand_nodes_by_similarity(&vector, limit, window)?;
- Ok(selection.into())
- })
- .await
+ selection
+ .expand_nodes_by_similarity(&vector, limit, window)
+ .await?;
+ Ok(selection.into())
}
/// Add the adjacent edges with higher score for query to the selection up to a specified limit. This function loops like expand_entities_by_similarity but is restricted to edges.
async fn expand_edges_by_similarity(
&self,
- ctx: &Context<'_>,
query: String,
limit: usize,
window: Option,
) -> GraphResult {
- let vector = ctx.embed_query(query).await?;
+ let vector = self.embed_text(query).await?;
let window = window.into_window_tuple();
let mut selection = self.cloned();
- blocking_compute(move || {
- selection.expand_edges_by_similarity(&vector, limit, window)?;
- Ok(selection.into())
- })
- .await
+ selection
+ .expand_edges_by_similarity(&vector, limit, window)
+ .await?;
+ Ok(selection.into())
}
}
@@ -156,4 +141,8 @@ impl GqlVectorSelection {
fn cloned(&self) -> VectorSelection {
self.0.clone()
}
+
+ async fn embed_text(&self, text: String) -> GraphResult {
+ self.0.get_vectorised_graph().embed_text(text).await
+ }
}
diff --git a/raphtory-graphql/src/model/graph/vectorised_graph.rs b/raphtory-graphql/src/model/graph/vectorised_graph.rs
index 55920fc76c..68a5a17415 100644
--- a/raphtory-graphql/src/model/graph/vectorised_graph.rs
+++ b/raphtory-graphql/src/model/graph/vectorised_graph.rs
@@ -1,6 +1,4 @@
use super::vector_selection::GqlVectorSelection;
-use crate::{embeddings::EmbedQuery, model::blocking_io};
-use async_graphql::Context;
use dynamic_graphql::{InputObject, ResolvedObject, ResolvedObjectFields};
use raphtory::{
db::api::view::MaterializedGraph, errors::GraphResult,
@@ -37,6 +35,12 @@ impl From> for GqlVectorisedGraph {
#[ResolvedObjectFields]
impl GqlVectorisedGraph {
+ /// Optimize the vector index
+ async fn optimize_index(&self) -> GraphResult {
+ self.0.optimize_index().await?;
+ Ok(true)
+ }
+
/// Returns an empty selection of documents.
async fn empty_selection(&self) -> GqlVectorSelection {
self.0.empty_selection().into()
@@ -45,42 +49,42 @@ impl GqlVectorisedGraph {
/// Search the top scoring entities according to a specified query returning no more than a specified limit of entities.
async fn entities_by_similarity(
&self,
- ctx: &Context<'_>,
query: String,
limit: usize,
window: Option,
) -> GraphResult {
- let vector = ctx.embed_query(query).await?;
+ let vector = self.0.embed_text(query).await?;
let w = window.into_window_tuple();
let cloned = self.0.clone();
- blocking_io(move || Ok(cloned.entities_by_similarity(&vector, limit, w)?.into())).await
+ Ok(cloned
+ .entities_by_similarity(&vector, limit, w)
+ .await?
+ .into())
}
/// Search the top scoring nodes according to a specified query returning no more than a specified limit of nodes.
async fn nodes_by_similarity(
&self,
- ctx: &Context<'_>,
query: String,
limit: usize,
window: Option,
) -> GraphResult {
- let vector = ctx.embed_query(query).await?;
+ let vector = self.0.embed_text(query).await?;
let w = window.into_window_tuple();
let cloned = self.0.clone();
- blocking_io(move || Ok(cloned.nodes_by_similarity(&vector, limit, w)?.into())).await
+ Ok(cloned.nodes_by_similarity(&vector, limit, w).await?.into())
}
/// Search the top scoring edges according to a specified query returning no more than a specified limit of edges.
async fn edges_by_similarity(
&self,
- ctx: &Context<'_>,
query: String,
limit: usize,
window: Option,
) -> GraphResult {
- let vector = ctx.embed_query(query).await?;
+ let vector = self.0.embed_text(query).await?;
let w = window.into_window_tuple();
let cloned = self.0.clone();
- blocking_io(move || Ok(cloned.edges_by_similarity(&vector, limit, w)?.into())).await
+ Ok(cloned.edges_by_similarity(&vector, limit, w).await?.into())
}
}
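With per-graph embedding models, similarity queries no longer need the server-wide embedding conf; a sketch of a distance-ranked GraphQL search (assumption: the QueryRoot field returning `GqlVectorisedGraph` is named `vectorisedGraph` — it is elided in the schema excerpt above, so treat that field name as illustrative):

```python
from raphtory.graphql import RaphtoryClient

client = RaphtoryClient("http://localhost:1736")
result = client.query(
    """
    {
      vectorisedGraph(path: "abb") {
        optimizeIndex
        entitiesBySimilarity(query: "node1", limit: 3) {
          getDocuments { content score }
        }
      }
    }
    """
)
```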
diff --git a/raphtory-graphql/src/model/mod.rs b/raphtory-graphql/src/model/mod.rs
index eadf0ec654..61d2f7b200 100644
--- a/raphtory-graphql/src/model/mod.rs
+++ b/raphtory-graphql/src/model/mod.rs
@@ -9,20 +9,25 @@ use crate::{
},
plugins::{mutation_plugin::MutationPlugin, query_plugin::QueryPlugin},
},
- paths::valid_path,
+ paths::{valid_path, ExistingGraphFolder},
rayon::blocking_compute,
url_encode::{url_decode_graph, url_encode_graph},
};
use async_graphql::Context;
use dynamic_graphql::{
- App, Enum, Mutation, MutationFields, MutationRoot, ResolvedObject, ResolvedObjectFields,
- Result, Upload,
+ App, Enum, InputObject, Mutation, MutationFields, MutationRoot, OneOfInput, ResolvedObject,
+ ResolvedObjectFields, Result, Upload,
};
use raphtory::{
db::{api::view::MaterializedGraph, graph::views::deletion_graph::PersistentGraph},
- errors::{GraphError, InvalidPathReason},
+ errors::{GraphError, GraphResult, InvalidPathReason},
prelude::*,
serialise::InternalStableDecode,
+ vectors::{
+ cache::CachedEmbeddingModel,
+ storage::OpenAIEmbeddings,
+ template::{DocumentTemplate, DEFAULT_EDGE_TEMPLATE, DEFAULT_NODE_TEMPLATE},
+ },
version,
};
#[cfg(feature = "storage")]
@@ -40,6 +45,47 @@ pub mod plugins;
pub(crate) mod schema;
pub(crate) mod sorting;
+// TODO: move somewhere else
+#[derive(InputObject, Debug, Clone, Default)]
+pub struct OpenAIConfig {
+ model: String,
+ api_base: Option,
+ api_key_env: Option,
+ org_id: Option,
+ project_id: Option,
+}
+
+#[derive(OneOfInput, Clone, Debug)]
+pub enum EmbeddingModel {
+ /// OpenAI embedding models or compatible providers
+ OpenAI(OpenAIConfig),
+}
+
+impl EmbeddingModel {
+ async fn cache<'a>(self, ctx: &Context<'a>) -> GraphResult {
+ let data = ctx.data_unchecked::();
+ match self {
+ Self::OpenAI(OpenAIConfig {
+ model,
+ api_base,
+ api_key_env,
+ org_id,
+ project_id,
+ }) => {
+ let embeddings = OpenAIEmbeddings {
+ model,
+ api_base,
+ api_key_env,
+ org_id,
+ project_id,
+ };
+ let vector_cache = data.vector_cache.resolve().await?;
+ vector_cache.openai(embeddings).await
+ }
+ }
+ }
+}
+
/// a thin wrapper around spawn_blocking that unwraps the join handle
pub(crate) async fn blocking_io(f: F) -> R
where
@@ -87,6 +133,22 @@ pub enum GqlGraphType {
#[graphql(root)]
pub(crate) struct QueryRoot;
+#[derive(OneOfInput, Clone, Debug)]
+pub enum Template {
+ /// The default template.
+ Enabled(bool),
+ /// A custom template.
+ Custom(String),
+}
+
+fn resolve(template: Option, default: &str) -> Option {
+ match template? {
+ Template::Enabled(false) => None,
+ Template::Enabled(true) => Some(default.to_owned()),
+ Template::Custom(template) => Some(template),
+ }
+}
+
#[ResolvedObjectFields]
impl QueryRoot {
/// Hello world demo
@@ -102,6 +164,7 @@ impl QueryRoot {
.await
.map(|(g, folder)| GqlGraph::new(folder, g.graph))?)
}
+
/// Update graph query, has side effects to update graph state
///
/// Returns:: GqlMutableGraph
@@ -116,6 +179,32 @@ impl QueryRoot {
Ok(graph)
}
+ /// Vectorise the graph at the given path, has side effects to build or update its vector index
+ ///
+ /// Returns:: Boolean
+ async fn vectorise_graph<'a>(
+ ctx: &Context<'a>,
+ path: String,
+ model: Option,
+ nodes: Option,
+ edges: Option,
+ ) -> Result {
+ ctx.require_write_access()?;
+ let data = ctx.data_unchecked::();
+ let template = DocumentTemplate {
+ node_template: resolve(nodes, DEFAULT_NODE_TEMPLATE),
+ edge_template: resolve(edges, DEFAULT_EDGE_TEMPLATE),
+ };
+ let cached_model = model
+ .unwrap_or(EmbeddingModel::OpenAI(Default::default()))
+ .cache(ctx)
+ .await?;
+ let folder = ExistingGraphFolder::try_from(data.work_dir.clone(), &path)?;
+ data.vectorise_folder(&folder, &template, cached_model)
+ .await?;
+ Ok(true)
+ }
+
/// Create vectorised graph in the format used for queries
///
/// Returns:: GqlVectorisedGraph
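For reference, the `Template`-to-string mapping implemented by `resolve` above, as a small Python mirror (illustrative only, not part of the raphtory API):

```python
# Illustrative mirror of the Rust `resolve` helper: maps the Template oneOf
# input to the template string actually used for vectorisation.
def resolve(template, default):
    if template is None:                # argument omitted: nothing is embedded
        return None
    if "enabled" in template:
        return default if template["enabled"] else None
    return template["custom"]           # { custom: "..." } is used verbatim

assert resolve(None, "DEFAULT") is None
assert resolve({"enabled": True}, "DEFAULT") == "DEFAULT"
assert resolve({"enabled": False}, "DEFAULT") is None
assert resolve({"custom": "{{ name }}"}, "DEFAULT") == "{{ name }}"
```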
diff --git a/raphtory-graphql/src/python/server/mod.rs b/raphtory-graphql/src/python/server/mod.rs
index a5bf483fe1..fdd38ac2d1 100644
--- a/raphtory-graphql/src/python/server/mod.rs
+++ b/raphtory-graphql/src/python/server/mod.rs
@@ -1,11 +1,5 @@
-use crate::{
- python::{
- server::{running_server::ServerHandler, server::PyGraphServer},
- RUNNING_SERVER_CONSUMED_MSG,
- },
- GraphServer,
-};
-use pyo3::{exceptions::PyException, PyRefMut, PyResult};
+use crate::python::{server::running_server::ServerHandler, RUNNING_SERVER_CONSUMED_MSG};
+use pyo3::{exceptions::PyException, PyResult};
use raphtory_api::python::error::adapt_err_value;
pub mod running_server;
@@ -15,14 +9,6 @@ pub(crate) enum BridgeCommand {
StopServer,
StopListening,
}
-pub fn take_server_ownership(mut server: PyRefMut) -> PyResult {
- let new_server = server.0.take().ok_or_else(|| {
- PyException::new_err(
- "Server object has already been used, please create another one from scratch",
- )
- })?;
- Ok(new_server)
-}
pub(crate) fn wait_server(running_server: &mut Option) -> PyResult<()> {
let owned_running_server = running_server
diff --git a/raphtory-graphql/src/python/server/server.rs b/raphtory-graphql/src/python/server/server.rs
index e3abd0e85c..75dde63e1e 100644
--- a/raphtory-graphql/src/python/server/server.rs
+++ b/raphtory-graphql/src/python/server/server.rs
@@ -3,24 +3,21 @@ use crate::{
app_config::AppConfigBuilder, auth_config::PUBLIC_KEY_DECODING_ERR_MSG,
otlp_config::TracingLevel,
},
- python::server::{
- running_server::PyRunningGraphServer, take_server_ownership, wait_server, BridgeCommand,
- },
+ python::server::{running_server::PyRunningGraphServer, wait_server, BridgeCommand},
GraphServer,
};
use pyo3::{
exceptions::{PyAttributeError, PyException, PyValueError},
prelude::*,
- types::PyFunction,
};
use raphtory::{
- python::packages::vectors::TemplateConfig,
- vectors::{
- embeddings::{openai_embedding, EmbeddingFunction},
- template::{DocumentTemplate, DEFAULT_EDGE_TEMPLATE, DEFAULT_NODE_TEMPLATE},
+ python::{
+ packages::vectors::{PyOpenAIEmbeddings, TemplateConfig},
+ utils::block_on,
},
+ vectors::template::{DocumentTemplate, DEFAULT_EDGE_TEMPLATE, DEFAULT_NODE_TEMPLATE},
};
-use std::{path::PathBuf, sync::Arc, thread};
+use std::{path::PathBuf, thread};
/// A class for defining and running a Raphtory GraphQL server
///
@@ -38,7 +35,7 @@ use std::{path::PathBuf, sync::Arc, thread};
/// auth_enabled_for_reads:
/// create_index:
#[pyclass(name = "GraphServer", module = "raphtory.graphql")]
-pub struct PyGraphServer(pub Option);
+pub struct PyGraphServer(GraphServer);
impl<'py> IntoPyObject<'py> for GraphServer {
type Target = PyGraphServer;
@@ -46,7 +43,7 @@ impl<'py> IntoPyObject<'py> for GraphServer {
type Error = >::Error;
fn into_pyobject(self, py: Python<'py>) -> Result {
- PyGraphServer::new(self).into_pyobject(py)
+ PyGraphServer(self).into_pyobject(py)
}
}
@@ -61,26 +58,6 @@ fn template_from_python(nodes: TemplateConfig, edges: TemplateConfig) -> Option<
}
}
-impl PyGraphServer {
- pub fn new(server: GraphServer) -> Self {
- Self(Some(server))
- }
-
- fn set_generic_embeddings(
- slf: PyRefMut,
- cache: String,
- embedding: F,
- nodes: TemplateConfig,
- edges: TemplateConfig,
- ) -> PyResult {
- let global_template = template_from_python(nodes, edges);
- let server = take_server_ownership(slf)?;
- let cache = PathBuf::from(cache);
- let rt = tokio::runtime::Runtime::new().unwrap();
- Ok(rt.block_on(server.set_embeddings(embedding, &cache, global_template))?)
- }
-}
-
#[pymethods]
impl PyGraphServer {
#[new]
@@ -148,74 +125,53 @@ impl PyGraphServer {
}
let app_config = Some(app_config_builder.build());
- let server = GraphServer::new(work_dir, app_config, config_path)?;
- Ok(PyGraphServer::new(server))
+ let server = block_on(GraphServer::new(work_dir, app_config, config_path))?;
+ Ok(PyGraphServer(server))
}
+ // TODO: remove this, should be config
/// Turn off index for all graphs
- ///
- /// Returns:
- /// GraphServer: The server with indexing disabled
- fn turn_off_index(slf: PyRefMut) -> PyResult {
- let server = take_server_ownership(slf)?;
- Ok(server.turn_off_index())
+ fn turn_off_index(mut slf: PyRefMut<Self>) {
+ slf.0.turn_off_index()
}
- /// Setup the server to vectorise graphs with a default template.
+ /// Vectorise the graph 'name' in the server working directory.
///
/// Arguments:
- /// cache (str): the directory to use as cache for the embeddings.
- /// embedding (Callable, optional): the embedding function to translate documents to embeddings.
- /// nodes (bool | str): if nodes have to be embedded or not or the custom template to use if a str is provided. Defaults to True.
- /// edges (bool | str): if edges have to be embedded or not or the custom template to use if a str is provided. Defaults to True.
- ///
- /// Returns:
- /// GraphServer: A new server object with embeddings setup.
- #[pyo3(
- signature = (cache, embedding = None, nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true))
- )]
- fn set_embeddings(
- slf: PyRefMut<Self>,
- cache: String,
- embedding: Option<Py<PyFunction>>,
- nodes: TemplateConfig,
- edges: TemplateConfig,
- ) -> PyResult<GraphServer> {
- match embedding {
- Some(embedding) => {
- let embedding: Arc<dyn EmbeddingFunction> = Arc::new(embedding);
- Self::set_generic_embeddings(slf, cache, embedding, nodes, edges)
- }
- None => Self::set_generic_embeddings(slf, cache, openai_embedding, nodes, edges),
- }
- }
-
- /// Vectorise a subset of the graphs of the server.
- ///
- /// Arguments:
- /// graph_names (list[str]): the names of the graphs to vectorise. All by default.
+ /// name (str): the name of the graph to vectorise.
+ /// embeddings (OpenAIEmbeddings): the embedding model configuration to use.
/// nodes (bool | str): if nodes have to be embedded or not or the custom template to use if a str is provided. Defaults to True.
/// edges (bool | str): if edges have to be embedded or not or the custom template to use if a str is provided. Defaults to True.
///
/// Returns:
- /// GraphServer: A new server object containing the vectorised graphs.
+ /// None:
#[pyo3(
- signature = (graph_names, nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true))
+ signature = (name, embeddings, nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true))
)]
- fn with_vectorised_graphs(
- slf: PyRefMut<Self>,
- graph_names: Vec<String>,
- // TODO: support more models by just providing a string, For example, "openai", here and in the VectorisedGraph API
+ fn vectorise_graph(
+ &self,
+ py: Python,
+ name: &str,
+ embeddings: PyOpenAIEmbeddings, // FIXME: this will create a breaking change once there are more options
nodes: TemplateConfig,
edges: TemplateConfig,
- ) -> PyResult<GraphServer> {
+ ) -> PyResult<()> {
let template = template_from_python(nodes, edges).ok_or(PyAttributeError::new_err(
- "node_template and/or edge_template has to be set",
+ "at least one of nodes and edges has to be set to True or some string",
))?;
- let server = take_server_ownership(slf)?;
- Ok(server.with_vectorised_graphs(graph_names, template))
+ let rt = tokio::runtime::Runtime::new().unwrap();
+ // allow threads just in case the embedding server is using the same python runtime
+ py.allow_threads(|| {
+ rt.block_on(async move {
+ self.0
+ .vectorise_graph(name, template, embeddings.into())
+ .await?;
+ Ok(())
+ })
+ })
}
+ // TODO: vectorise all graphs
+
/// Start the server and return a handle to it.
///
/// Arguments:
@@ -229,16 +185,10 @@ impl PyGraphServer {
#[pyo3(
signature = (port = 1736, timeout_ms = 5000)
)]
- pub fn start(
- slf: PyRefMut<Self>,
- py: Python,
- port: u16,
- timeout_ms: u64,
- ) -> PyResult<PyRunningGraphServer> {
+ pub fn start(&self, py: Python, port: u16, timeout_ms: u64) -> PyResult<PyRunningGraphServer> {
let (sender, receiver) = crossbeam_channel::bounded::<BridgeCommand>(1);
let cloned_sender = sender.clone();
-
- let server = take_server_ownership(slf)?;
+ let server = self.0.clone();
let join_handle = thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
@@ -263,9 +213,7 @@ impl PyGraphServer {
let mut server = PyRunningGraphServer::new(join_handle, sender, port)?;
if let Some(_server_handler) = &server.server_handler {
let url = format!("http://localhost:{port}");
- // we need to release the GIL, otherwise the server will deadlock when trying to use python function as the embedding function
- // and wait_for_server_online will never return
- let result = py.allow_threads(|| server.wait_for_server_online(&url, timeout_ms));
+ let result = server.wait_for_server_online(&url, timeout_ms);
match result {
Ok(_) => return Ok(server),
Err(e) => {
@@ -289,8 +237,8 @@ impl PyGraphServer {
#[pyo3(
signature = (port = 1736, timeout_ms = 180000)
)]
- pub fn run(slf: PyRefMut<Self>, py: Python, port: u16, timeout_ms: u64) -> PyResult<()> {
- let mut server = Self::start(slf, py, port, timeout_ms)?.server_handler;
+ pub fn run(&self, py: Python, port: u16, timeout_ms: u64) -> PyResult<()> {
+ let mut server = self.start(py, port, timeout_ms)?.server_handler;
py.allow_threads(|| wait_server(&mut server))
}
}
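For orientation, the resulting Python API on the server side now looks roughly like the following; a minimal sketch, assuming the working directory `graphs` already contains a graph saved as `g` (the constructor argument and the node template are illustrative):

```python
from raphtory.graphql import GraphServer
from raphtory.vectors import OpenAIEmbeddings

server = GraphServer("graphs")

# Embeddings are now passed per call instead of being configured once via
# the removed set_embeddings(); vectorise_graph returns None and no longer
# consumes the server object.
server.vectorise_graph(
    "g",
    OpenAIEmbeddings(model="text-embedding-3-small"),
    nodes="{{ name }}",  # custom node template
    edges=False,         # skip edge documents
)

server.run(port=1736)
```
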
diff --git a/raphtory-graphql/src/rayon.rs b/raphtory-graphql/src/rayon.rs
index 03fa57bb75..0d6b063c3b 100644
--- a/raphtory-graphql/src/rayon.rs
+++ b/raphtory-graphql/src/rayon.rs
@@ -62,7 +62,9 @@ mod deadlock_tests {
async fn test_pool_lock(port: u16, pool_lock: impl FnOnce(Arc<Mutex<()>>)) {
let tempdir = TempDir::new().unwrap();
- let server = GraphServer::new(tempdir.path().to_path_buf(), None, None).unwrap();
+ let server = GraphServer::new(tempdir.path().to_path_buf(), None, None)
+ .await
+ .unwrap();
let _running = server.start_with_port(port).await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await; // this is to wait for the server to be up
let lock = Arc::new(Mutex::new(()));
diff --git a/raphtory-graphql/src/server.rs b/raphtory-graphql/src/server.rs
index 2f4c9b43a3..c22c32ffe8 100644
--- a/raphtory-graphql/src/server.rs
+++ b/raphtory-graphql/src/server.rs
@@ -1,12 +1,13 @@
use crate::{
auth::{AuthenticatedGraphQL, MutationAuth},
config::app_config::{load_config, AppConfig},
- data::{Data, EmbeddingConf},
+ data::Data,
model::{
plugins::{entry_point::EntryPoint, operation::Operation},
App,
},
observability::open_telemetry::OpenTelemetry,
+ paths::ExistingGraphFolder,
routes::{health, version, PublicFilesEndpoint},
server::ServerError::SchemaError,
};
@@ -22,13 +23,14 @@ use poem::{
};
use raphtory::{
errors::GraphResult,
- vectors::{cache::VectorCache, embeddings::EmbeddingFunction, template::DocumentTemplate},
+ vectors::{
+ cache::{CachedEmbeddingModel, VectorCache},
+ storage::OpenAIEmbeddings,
+ template::DocumentTemplate,
+ },
};
use serde_json::json;
-use std::{
- fs::create_dir_all,
- path::{Path, PathBuf},
-};
+use std::{fs::create_dir_all, path::PathBuf};
use thiserror::Error;
use tokio::{
io,
@@ -78,6 +80,7 @@ impl From<ServerError> for io::Error {
}
/// A struct for defining and running a Raphtory GraphQL server
+#[derive(Clone)]
pub struct GraphServer {
data: Data,
config: AppConfig,
@@ -108,7 +111,7 @@ impl GraphServer {
///
/// Returns:
/// IoResult:
- pub fn new(
+ pub async fn new(
work_dir: PathBuf,
app_config: Option<AppConfig>,
config_path: Option<PathBuf>,
@@ -122,58 +125,57 @@ impl GraphServer {
}
/// Turn off index for all graphs
- pub fn turn_off_index(mut self) -> Self {
- self.data.create_index = false;
- self
- }
-
- pub async fn set_embeddings<F: EmbeddingFunction + 'static>(
- mut self,
- embedding: F,
- cache: &Path,
- // or maybe it could be in a standard location like /tmp/raphtory/embedding_cache
- global_template: Option<DocumentTemplate>,
- ) -> GraphResult<Self> {
- self.data.embedding_conf = Some(EmbeddingConf {
- cache: VectorCache::on_disk(cache, embedding).await?, // TODO: better do this lazily, actually do it when running the server
- global_template,
- individual_templates: Default::default(),
- });
- Ok(self)
+ pub fn turn_off_index(&mut self) {
+ self.data.create_index = false; // FIXME: why does this still exist? it should be config
}
- /// Vectorise a subset of the graphs of the server.
+ /// Vectorise all the graphs in the server working directory.
///
/// Arguments:
- /// * graph_names - the names of the graphs to vectorise. All if None is provided.
- /// * embedding - the embedding function to translate documents to embeddings.
- /// * cache - the directory to use as cache for the embeddings.
+ /// * embeddings - the embedding model configuration to use.
/// * template - the template to use for creating documents.
///
/// Returns:
- /// A new server object containing the vectorised graphs.
+ /// A result indicating whether vectorisation succeeded.
- pub fn with_vectorised_graphs(
- mut self,
- graph_names: Vec<String>,
- template: DocumentTemplate,
- ) -> Self {
- if let Some(embedding_conf) = &mut self.data.embedding_conf {
- for graph_name in graph_names {
- embedding_conf
- .individual_templates
- .insert(graph_name.into(), template.clone());
- }
+ pub async fn vectorise_all_graphs(
+ &self,
+ template: &DocumentTemplate,
+ embeddings: OpenAIEmbeddings,
+ ) -> GraphResult<()> {
+ let vector_cache = self.data.vector_cache.resolve().await?;
+ let model = vector_cache.openai(embeddings).await?;
+ for folder in self.data.get_all_graph_folders() {
+ self.data
+ .vectorise_folder(&folder, template, model.clone()) // TODO: avoid clone, just ask for a ref
+ .await?;
}
- self
+ Ok(())
+ }
+
+ /// Vectorise the graph 'name' in the server working directory.
+ ///
+ /// Arguments:
+ /// * path - the path of the graph to vectorise.
+ /// * template - the template to use for creating documents.
+ /// * embeddings - the embedding model configuration to use.
+ pub async fn vectorise_graph(
+ &self,
+ path: &str,
+ template: DocumentTemplate,
+ embeddings: OpenAIEmbeddings,
+ ) -> GraphResult<()> {
+ let vector_cache = self.data.vector_cache.resolve().await?;
+ let model = vector_cache.openai(embeddings).await?;
+ let folder = ExistingGraphFolder::try_from(self.data.work_dir.clone(), path)?;
+ self.data.vectorise_folder(&folder, &template, model).await
}
/// Start the server on the default port and return a handle to it.
- pub async fn start(self) -> IoResult<RunningGraphServer> {
+ pub async fn start(&self) -> IoResult<RunningGraphServer> {
self.start_with_port(DEFAULT_PORT).await
}
/// Start the server on the given port and return a handle to it.
- pub async fn start_with_port(self, port: u16) -> IoResult<RunningGraphServer> {
+ pub async fn start_with_port(&self, port: u16) -> IoResult<RunningGraphServer> {
// set up opentelemetry first of all
let config = self.config.clone();
let filter = config.logging.get_log_env();
@@ -197,7 +199,6 @@ impl GraphServer {
}
};
- self.data.vectorise_all_graphs_that_are_not().await?;
let work_dir = self.data.work_dir.clone();
// it is important that this runs after algorithms have been pushed to PLUGIN_ALGOS static variable
@@ -227,11 +228,11 @@ impl GraphServer {
}
async fn generate_endpoint(
- self,
+ &self,
tracer: Option,
) -> Result>, ServerError> {
let schema_builder = App::create_schema();
- let schema_builder = schema_builder.data(self.data);
+ let schema_builder = schema_builder.data(self.data.clone());
let schema_builder = schema_builder.extension(MutationAuth);
let trace_level = self.config.tracing.tracing_level.clone();
let schema = if let Some(t) = tracer {
@@ -247,8 +248,8 @@ impl GraphServer {
.nest(
"/",
PublicFilesEndpoint::new(
- self.config.public_dir,
- AuthenticatedGraphQL::new(schema, self.config.auth),
+ self.config.public_dir.clone(),
+ AuthenticatedGraphQL::new(schema, self.config.auth.clone()),
),
)
.at("/health", get(health))
@@ -340,7 +341,10 @@ mod server_tests {
use chrono::prelude::*;
use raphtory::{
prelude::{AdditionOps, Graph, StableEncode, NO_PROPS},
- vectors::{embeddings::EmbeddingResult, template::DocumentTemplate, Embedding},
+ vectors::{
+ embeddings::EmbeddingResult, storage::OpenAIEmbeddings, template::DocumentTemplate,
+ Embedding,
+ },
};
use raphtory_api::core::utils::logging::global_info_logger;
use tempfile::tempdir;
@@ -351,7 +355,9 @@ mod server_tests {
async fn test_server_start_stop() {
global_info_logger();
let tmp_dir = tempdir().unwrap();
- let server = GraphServer::new(tmp_dir.path().to_path_buf(), None, None).unwrap();
+ let server = GraphServer::new(tmp_dir.path().to_path_buf(), None, None)
+ .await
+ .unwrap();
info!("Calling start at time {}", Local::now());
let handler = server.start_with_port(0);
sleep(Duration::from_secs(1)).await;
@@ -359,16 +365,6 @@ mod server_tests {
handler.await.unwrap().stop().await
}
- #[derive(thiserror::Error, Debug)]
- enum SomeError {
- #[error("A variant of this error")]
- Variant,
- }
-
- async fn failing_embedding(_texts: Vec<String>) -> EmbeddingResult<Vec<Embedding>> {
- Err(SomeError::Variant.into())
- }
-
#[tokio::test]
async fn test_server_start_with_failing_embedding() {
let tmp_dir = tempdir().unwrap();
@@ -377,17 +373,23 @@ mod server_tests {
graph.encode(tmp_dir.path().join("g")).unwrap();
global_info_logger();
- let server = GraphServer::new(tmp_dir.path().to_path_buf(), None, None).unwrap();
+ let server = GraphServer::new(tmp_dir.path().to_path_buf(), None, None)
+ .await
+ .unwrap();
let template = DocumentTemplate {
node_template: Some("{{ name }}".to_owned()),
..Default::default()
};
- let cache_dir = tempdir().unwrap();
- let handler = server
- .set_embeddings(failing_embedding, cache_dir.path(), Some(template))
- .await
- .unwrap()
- .start_with_port(0);
+ let model = OpenAIEmbeddings {
+ api_base: Some("wrong-api-base".to_owned()),
+ model: "whatever".to_owned(),
+ api_key_env: None,
+ project_id: None,
+ org_id: None,
+ };
+ let result = server.vectorise_all_graphs(&template, model).await;
+ assert!(result.is_err());
+ let handler = server.start_with_port(0);
sleep(Duration::from_secs(5)).await;
handler.await.unwrap().stop().await
}
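The reworked test reflects that embedding failures now surface when vectorising rather than at server startup. From Python the same behaviour would look roughly like this sketch (the bad `api_base` mirrors the test above; the graph name is illustrative):

```python
from raphtory.graphql import GraphServer
from raphtory.vectors import OpenAIEmbeddings

server = GraphServer("graphs")
bad_model = OpenAIEmbeddings(model="whatever", api_base="wrong-api-base")

try:
    server.vectorise_graph("g", bad_model)
except Exception as err:
    # Vectorisation fails eagerly; the server itself can still be started
    print(f"vectorisation failed: {err}")

server.run()
```
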
diff --git a/raphtory/Cargo.toml b/raphtory/Cargo.toml
index 66a566b191..73870fb104 100644
--- a/raphtory/Cargo.toml
+++ b/raphtory/Cargo.toml
@@ -85,9 +85,11 @@ async-openai = { workspace = true, optional = true }
bincode = { workspace = true, optional = true }
minijinja = { workspace = true, optional = true }
minijinja-contrib = { workspace = true, optional = true }
-arroy = { workspace = true, optional = true }
heed = { workspace = true, optional = true }
moka = { workspace = true, optional = true }
+lancedb = { workspace = true, optional = true }
+arrow-array = { workspace = true, features = ["chrono-tz"], optional = true }
+axum = "0.8.4" # TODO put this in the proper place and make optional
# python binding optional dependencies
pyo3 = { workspace = true, optional = true }
@@ -140,9 +142,11 @@ vectors = [
"dep:minijinja",
"dep:minijinja-contrib",
"raphtory-api/template",
- "dep:arroy",
"dep:heed",
"dep:moka",
+ "dep:lancedb",
+ "dep:arrow-array",
+ "dep:tokio", # also used for the io feature
"dep:tempfile", # also used for the storage feature
]
diff --git a/raphtory/src/errors.rs b/raphtory/src/errors.rs
index a52c0fd8ef..fc2c0c698b 100644
--- a/raphtory/src/errors.rs
+++ b/raphtory/src/errors.rs
@@ -1,3 +1,5 @@
+#[cfg(feature = "vectors")]
+use crate::vectors::{embeddings::EmbeddingError, Embedding};
use crate::{
core::storage::lazy_vec::IllegalSet,
db::graph::views::filter::model::filter_operator::FilterOperator, prelude::GraphViewOps,
@@ -247,17 +249,21 @@ pub enum GraphError {
IOErrorMsg(String),
#[cfg(feature = "vectors")]
- #[error("Arroy error: {0}")]
- ArroyError(#[from] arroy::Error),
+ #[error("Heed error: {0}")]
+ HeedError(#[from] heed::Error),
#[cfg(feature = "vectors")]
#[error("Heed error: {0}")]
- HeedError(#[from] heed::Error),
+ LanceDbError(#[from] lancedb::Error),
#[cfg(feature = "vectors")]
#[error("The path {0} does not contain a vector DB")]
VectorDbDoesntExist(String),
+ #[cfg(feature = "vectors")]
+ #[error("The schema of the vector DB is invalid")]
+ InvalidVectorDbSchema,
+
#[cfg(feature = "proto")]
#[error("zip operation failed")]
ZipError {
@@ -294,9 +300,13 @@ pub enum GraphError {
#[error("Embedding operation failed")]
EmbeddingError {
#[from]
- source: Box<dyn std::error::Error + Send + Sync>,
+ source: EmbeddingError,
},
+ #[cfg(feature = "vectors")]
+ #[error("Embedding model sample changed from {0:?} to {1:?}")]
+ InvalidModelSample(Embedding, Embedding),
+
#[cfg(feature = "search")]
#[error("Index operation failed")]
QueryError {
diff --git a/raphtory/src/python/packages/base_modules.rs b/raphtory/src/python/packages/base_modules.rs
index 7d22e19fc2..eb3c917a09 100644
--- a/raphtory/src/python/packages/base_modules.rs
+++ b/raphtory/src/python/packages/base_modules.rs
@@ -31,7 +31,7 @@ use crate::{
algorithms::*,
graph_gen::*,
graph_loader::*,
- vectors::{PyVectorSelection, PyVectorisedGraph},
+ vectors::{embedding_server, PyOpenAIEmbeddings, PyVectorSelection, PyVectorisedGraph},
},
types::{
result_iterable::{
@@ -248,10 +248,15 @@ pub fn base_graph_gen_module(py: Python<'_>) -> Result, PyEr
pub fn base_vectors_module(py: Python<'_>) -> Result<Bound<'_, PyModule>, PyErr> {
let vectors_module = PyModule::new(py, "vectors")?;
- vectors_module.add_class::<PyVectorisedGraph>()?;
- vectors_module.add_class::<PyDocument>()?;
- vectors_module.add_class::<PyEmbedding>()?;
- vectors_module.add_class::<PyVectorSelection>()?;
+ add_classes!(
+ &vectors_module,
+ PyVectorisedGraph,
+ PyDocument,
+ PyEmbedding,
+ PyVectorSelection,
+ PyOpenAIEmbeddings
+ );
+ add_functions!(&vectors_module, embedding_server);
Ok(vectors_module)
}
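With this registration in place, the new names are importable from `raphtory.vectors` alongside the existing classes:

```python
from raphtory.vectors import (
    VectorisedGraph,
    Document,
    Embedding,
    VectorSelection,
    OpenAIEmbeddings,   # new: embedding model configuration
    embedding_server,   # new: decorator that serves a Python embedding function
)
```
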
diff --git a/raphtory/src/python/packages/vectors.rs b/raphtory/src/python/packages/vectors.rs
index 9276f6741d..51771ef567 100644
--- a/raphtory/src/python/packages/vectors.rs
+++ b/raphtory/src/python/packages/vectors.rs
@@ -3,11 +3,12 @@ use crate::{
python::{
graph::{edge::PyEdge, node::PyNode, views::graph_view::PyGraphView},
types::wrappers::document::PyDocument,
- utils::{execute_async_task, PyNodeRef},
+ utils::{block_on, execute_async_task, PyNodeRef},
},
vectors::{
cache::VectorCache,
- embeddings::{EmbeddingFunction, EmbeddingResult},
+ custom::{serve_custom_embedding, EmbeddingFunction, EmbeddingServer},
+ storage::OpenAIEmbeddings,
template::{DocumentTemplate, DEFAULT_EDGE_TEMPLATE, DEFAULT_NODE_TEMPLATE},
vector_selection::DynamicVectorSelection,
vectorisable::Vectorisable,
@@ -15,10 +16,10 @@ use crate::{
Document, DocumentEntity, Embedding,
},
};
-use futures_util::future::BoxFuture;
+
use itertools::Itertools;
use pyo3::{
- exceptions::PyTypeError,
+ exceptions::{PyException, PyTypeError},
prelude::*,
types::{PyFunction, PyList},
};
@@ -26,10 +27,159 @@ use raphtory_api::core::{
storage::timeindex::{AsTime, EventTime},
utils::time::IntoTime,
};
-use std::path::PathBuf;
+use std::{path::PathBuf, sync::Arc};
+use tokio::runtime::Runtime;
type DynamicVectorisedGraph = VectorisedGraph<DynamicGraph>;
+#[pyclass(name = "OpenAIEmbeddings")]
+#[derive(Clone)]
+pub struct PyOpenAIEmbeddings {
+ model: String,
+ api_base: Option<String>,
+ api_key_env: Option<String>,
+ org_id: Option<String>,
+ project_id: Option<String>,
+}
+
+#[pymethods]
+impl PyOpenAIEmbeddings {
+ #[new]
+ #[pyo3(signature = (model="text-embedding-3-small", api_base=None, api_key_env=None, org_id=None, project_id=None))]
+ fn new(
+ model: &str,
+ api_base: Option<String>,
+ api_key_env: Option<String>,
+ org_id: Option<String>,
+ project_id: Option<String>,
+ ) -> Self {
+ Self {
+ model: model.to_owned(),
+ api_base,
+ api_key_env,
+ org_id,
+ project_id,
+ }
+ }
+}
+impl From<PyOpenAIEmbeddings> for OpenAIEmbeddings {
+ fn from(value: PyOpenAIEmbeddings) -> Self {
+ Self {
+ model: value.model,
+ api_base: value.api_base,
+ api_key_env: value.api_key_env,
+ org_id: value.org_id,
+ project_id: value.project_id,
+ }
+ }
+}
+
+impl EmbeddingFunction for Arc<Py<PyFunction>> {
+ fn call(&self, text: &str) -> Vec<f32> {
+ Python::with_gil(|py| {
+ // TODO: remove unwraps?
+ let any = self.call1(py, (text,)).unwrap();
+ let list = any.downcast_bound::<PyList>(py).unwrap();
+ list.iter().map(|value| value.extract().unwrap()).collect()
+ })
+ }
+}
+
+#[pyfunction]
+pub fn embedding_server(address: String) -> EmbeddingServerDecorator {
+ EmbeddingServerDecorator { address }
+}
+
+#[pyclass]
+struct EmbeddingServerDecorator {
+ address: String,
+}
+
+#[pymethods]
+impl EmbeddingServerDecorator {
+ fn __call__(&self, function: Py<PyFunction>) -> PyEmbeddingServer {
+ PyEmbeddingServer {
+ function: function.into(),
+ address: self.address.clone(),
+ }
+ }
+}
+
+// struct RunningServer {
+// runtime: Runtime,
+// server: EmbeddingServer,
+// }
+
+#[pyclass(name = "EmbeddingServer")]
+pub struct PyEmbeddingServer {
+ function: Arc<Py<PyFunction>>,
+ address: String,
+ // running: Option, // TODO: use all of these ideas for the GraphServer implementation
+}
+// TODO: ideally, I should allow users to provide this server object as embedding model, so the fact it has an OpenAI like API is transparent to the user
+
+impl PyEmbeddingServer {
+ fn create_running_server(&self) -> (Runtime, EmbeddingServer) {
+ let runtime = tokio::runtime::Runtime::new().unwrap();
+ // let runtime = tokio::runtime::Builder::new_multi_thread()
+ // .enable_all()
+ // .build()
+ // .unwrap();
+ let execution =
+ runtime.block_on(serve_custom_embedding(&self.address, self.function.clone()));
+ (runtime, execution)
+ }
+}
+
+#[pymethods]
+impl PyEmbeddingServer {
+ fn run(&self) {
+ let (runtime, execution) = self.create_running_server();
+ runtime.block_on(execution.wait());
+ }
+
+ fn start(&self) -> PyRunningEmbeddingServer {
+ let (runtime, execution) = self.create_running_server();
+ PyRunningEmbeddingServer {
+ runtime,
+ execution: Some(execution),
+ }
+ }
+}
+
+#[pyclass(name = "RunningEmbeddingServer")]
+struct PyRunningEmbeddingServer {
+ runtime: Runtime,
+ execution: Option, // TODO: rename EmbeddingServer to ServerHandle?
+}
+
+#[pymethods]
+impl PyRunningEmbeddingServer {
+ fn stop(&mut self) -> PyResult<()> {
+ if let Some(execution) = &mut self.execution {
+ self.runtime.block_on(execution.stop());
+ self.execution = None;
+ Ok(())
+ } else {
+ Err(PyException::new_err("Embedding server was already stopped"))
+ }
+ }
+
+ fn __enter__(slf: Py<Self>) -> Py<Self> {
+ slf
+ }
+
+ fn __exit__(
+ &mut self,
+ // py: Python,
+ _exc_type: PyObject,
+ _exc_val: PyObject,
+ _exc_tb: PyObject,
+ ) -> PyResult<()> {
+ self.stop()
+ }
+}
+
pub type PyWindow = Option<(EventTime, EventTime)>;
pub fn translate_window(window: PyWindow) -> Option<(i64, i64)> {
@@ -49,9 +199,9 @@ impl PyQuery {
) -> PyResult<Embedding> {
match self {
Self::Raw(query) => {
- let cache = graph.cache.clone();
+ let graph = graph.clone();
let result = Ok(execute_async_task(move || async move {
- cache.get_single(query).await
+ graph.embed_text(query).await
})?);
result
}
@@ -153,7 +303,7 @@ impl PyGraphView {
#[pyo3(signature = (embedding, nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true), cache = None, verbose = false))]
fn vectorise(
&self,
- embedding: Bound<PyFunction>,
+ embedding: PyOpenAIEmbeddings,
nodes: TemplateConfig,
edges: TemplateConfig,
cache: Option,
@@ -163,15 +313,15 @@ impl PyGraphView {
node_template: nodes.get_template_or(DEFAULT_NODE_TEMPLATE),
edge_template: edges.get_template_or(DEFAULT_EDGE_TEMPLATE),
};
- let embedding = embedding.unbind();
let graph = self.graph.clone();
execute_async_task(move || async move {
let cache = if let Some(cache) = cache {
- VectorCache::on_disk(&PathBuf::from(cache), embedding).await?
+ VectorCache::on_disk(&PathBuf::from(cache)).await?
} else {
- VectorCache::in_memory(embedding)
+ VectorCache::in_memory()
};
- Ok(graph.vectorise(cache, template, None, verbose).await?)
+ let model = cache.openai(embedding.into()).await?;
+ Ok(graph.vectorise(model, template, None, verbose).await?)
})
}
}
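`vectorise` on a graph view now takes an `OpenAIEmbeddings` config instead of a Python callable, and the cache stores embeddings only, with the model attached via `cache.openai(...)`. A usage sketch, with illustrative template and cache path:

```python
from raphtory import Graph
from raphtory.vectors import OpenAIEmbeddings

g = Graph()
g.add_node(1, "node1")

vg = g.vectorise(
    OpenAIEmbeddings(model="text-embedding-3-small"),
    nodes="{{ name }}",           # node document template
    edges=False,                  # do not embed edges
    cache="/tmp/embedding-cache", # optional; defaults to an in-memory cache
)
```
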
@@ -216,15 +366,20 @@ impl<'py> IntoPyObject<'py> for DynamicVectorSelection {
/// of those documents using a query and similarity scores.
#[pymethods]
impl PyVectorisedGraph {
+ /// Optimize the vector index.
+ fn optimize_index(&self) -> PyResult<()> {
+ Ok(block_on(self.0.optimize_index())?)
+ }
+
/// Return an empty selection of entities.
fn empty_selection(&self) -> DynamicVectorSelection {
self.0.empty_selection()
}
- /// Perform a similarity search between each entity's associated document and a specified `query`. Returns a number of entities up to a specified `limit` ranked in descending order of similarity score.
+ /// Perform a similarity search between each entity's associated document and a specified `query`. Returns a number of entities up to a specified `limit` ranked in ascending order of distance.
///
/// Args:
- /// query (str | list): The text or the embedding to score against.
+ /// query (str | list): The text or the embedding to calculate the distance from.
/// limit (int): The maximum number of new entities in the result.
/// window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
///
@@ -238,17 +393,17 @@ impl PyVectorisedGraph {
window: PyWindow,
) -> PyResult {
let embedding = query.into_embedding(&self.0)?;
- Ok(self
- .0
- .entities_by_similarity(&embedding, limit, translate_window(window))?)
+ let w = translate_window(window);
+ let s = block_on(self.0.entities_by_similarity(&embedding, limit, w))?;
+ Ok(s)
}
- /// Perform a similarity search between each node's associated document and a specified `query`. Returns a number of nodes up to a specified `limit` ranked in descending order of similarity score.
+ /// Perform a similarity search between each node's associated document and a specified `query`. Returns a number of nodes up to a specified `limit` ranked in ascending order of distance.
///
/// Args:
- /// query (str | list): The text or the embedding to score against.
+ /// query (str | list): The text or the embedding to calculate the distance from.
/// limit (int): The maximum number of new nodes in the result.
/// window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
///
/// Returns:
/// VectorSelection: The vector selection resulting from the search.
@@ -260,15 +415,14 @@ impl PyVectorisedGraph {
window: PyWindow,
) -> PyResult {
let embedding = query.into_embedding(&self.0)?;
- Ok(self
- .0
- .nodes_by_similarity(&embedding, limit, translate_window(window))?)
+ let w = translate_window(window);
+ Ok(block_on(self.0.nodes_by_similarity(&embedding, limit, w))?)
}
- /// Perform a similarity search between each edge's associated document and a specified `query`. Returns a number of edges up to a specified `limit` ranked in descending order of similarity score.
+ /// Perform a similarity search between each edge's associated document and a specified `query`. Returns a number of edges up to a specified `limit` ranked in ascending order of distance.
///
/// Args:
- /// query (str | list): The text or the embedding to score against.
+ /// query (str | list): The text or the embedding to calculate the distance from.
/// limit (int): The maximum number of new edges in the results.
/// window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
///
@@ -282,9 +436,8 @@ impl PyVectorisedGraph {
window: PyWindow,
) -> PyResult {
let embedding = query.into_embedding(&self.0)?;
- Ok(self
- .0
- .edges_by_similarity(&embedding, limit, translate_window(window))?)
+ let w = translate_window(window);
+ Ok(block_on(self.0.edges_by_similarity(&embedding, limit, w))?)
}
}
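Together with the `get_documents_with_scores` -> `get_documents_with_distances` rename below, query results are now ordered by ascending distance; a sketch continuing from the `vg` above:

```python
# Continues from the vectorised graph `vg` in the earlier example
selection = vg.entities_by_similarity("some query text", limit=10)
for doc, distance in selection.get_documents_with_distances():
    # smaller distance means more similar under the new ordering
    print(distance, doc)
```
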
@@ -324,20 +477,20 @@ impl PyVectorSelection {
/// Returns:
/// list[Document]: List of documents in the current selection.
fn get_documents(&self) -> PyResult<Vec<Document<DynamicGraph>>> {
- Ok(self.0.get_documents()?)
+ Ok(block_on(self.0.get_documents())?)
}
- /// Returns the documents present in the current selection alongside their scores.
+ /// Returns the documents present in the current selection alongside their distances.
///
/// Returns:
- /// list[Tuple[Document, float]]: List of documents and scores.
- fn get_documents_with_scores(&self) -> PyResult<Vec<(Document<DynamicGraph>, f32)>> {
- Ok(self.0.get_documents_with_scores()?)
+ /// list[Tuple[Document, float]]: List of documents and distances.
+ fn get_documents_with_distances(&self) -> PyResult<Vec<(Document<DynamicGraph>, f32)>> {
+ Ok(block_on(self.0.get_documents_with_distances())?)
}
/// Add all the documents associated with the specified `nodes` to the current selection.
///
- /// Documents added by this call are assumed to have a score of 0.
+ /// Documents added by this call are assumed to have a distance of 0.
///
/// Args:
/// nodes (list): List of the node ids or nodes to add.
@@ -350,7 +503,7 @@ impl PyVectorSelection {
/// Add all the documents associated with the specified `edges` to the current selection.
///
- /// Documents added by this call are assumed to have a score of 0.
+ /// Documents added by this call are assumed to have a distance of 0.
///
/// Args:
/// edges (list): List of the edge ids or edges to add.
@@ -390,20 +543,19 @@ impl PyVectorSelection {
self_.0.expand(hops, translate_window(window))
}
- /// Add the top `limit` adjacent entities with higher score for `query` to the selection
+ /// Add to the selection the `limit` adjacent entities closest to `query`
///
/// The expansion algorithm is a loop with two steps on each iteration:
///
/// 1. All the entities 1 hop away of some of the entities included on the selection (and
/// not already selected) are marked as candidates.
- /// 2. Those candidates are added to the selection in descending order according to the
- /// similarity score obtained against the `query`.
+ /// 2. Those candidates are added to the selection in ascending distance from `query`.
///
/// This loop goes on until the number of new entities reaches a total of `limit`
/// entities or until no more documents are available
///
/// Args:
- /// query (str | list): The text or the embedding to score against.
+ /// query (str | list): The text or the embedding to calculate the distance from.
/// limit (int): The number of documents to add.
/// window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
///
@@ -411,24 +563,24 @@ impl PyVectorSelection {
/// None:
#[pyo3(signature = (query, limit, window=None))]
fn expand_entities_by_similarity(
- mut self_: PyRefMut<'_, Self>,
+ mut slf: PyRefMut<'_, Self>,
query: PyQuery,
limit: usize,
window: PyWindow,
) -> PyResult<()> {
- let embedding = query.into_embedding(&self_.0.graph)?;
- self_
- .0
- .expand_entities_by_similarity(&embedding, limit, translate_window(window))?;
+ let embedding = query.into_embedding(&slf.0.graph)?;
+ let w = translate_window(window);
+ block_on(slf.0.expand_entities_by_similarity(&embedding, limit, w))?;
+
Ok(())
}
- /// Add the top `limit` adjacent nodes with higher score for `query` to the selection
+ /// Add to the selection the `limit` adjacent nodes closest to `query`
///
/// This function has the same behaviour as expand_entities_by_similarity but it only considers nodes.
///
/// Args:
- /// query (str | list): The text or the embedding to score against.
+ /// query (str | list): The text or the embedding to calculate the distance from.
/// limit (int): The maximum number of new nodes to add.
/// window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
///
@@ -436,24 +588,23 @@ impl PyVectorSelection {
/// None:
#[pyo3(signature = (query, limit, window=None))]
fn expand_nodes_by_similarity(
- mut self_: PyRefMut<'_, Self>,
+ mut slf: PyRefMut<'_, Self>,
query: PyQuery,
limit: usize,
window: PyWindow,
) -> PyResult<()> {
- let embedding = query.into_embedding(&self_.0.graph)?;
- self_
- .0
- .expand_nodes_by_similarity(&embedding, limit, translate_window(window))?;
+ let embedding = query.into_embedding(&slf.0.graph)?;
+ let w = translate_window(window);
+ block_on(slf.0.expand_nodes_by_similarity(&embedding, limit, w))?;
Ok(())
}
- /// Add the top `limit` adjacent edges with higher score for `query` to the selection
+ /// Add to the selection the `limit` adjacent edges closest to `query`
///
/// This function has the same behaviour as expand_entities_by_similarity but it only considers edges.
///
/// Args:
- /// query (str | list): The text or the embedding to score against.
+ /// query (str | list): The text or the embedding to calculate the distance from.
/// limit (int): The maximum number of new edges to add.
/// window (Tuple[int | str, int | str], optional): The window that documents need to belong to in order to be considered.
///
@@ -461,48 +612,14 @@ impl PyVectorSelection {
/// None:
#[pyo3(signature = (query, limit, window=None))]
fn expand_edges_by_similarity(
- mut self_: PyRefMut<'_, Self>,
+ mut slf: PyRefMut<'_, Self>,
query: PyQuery,
limit: usize,
window: PyWindow,
) -> PyResult<()> {
- let embedding = query.into_embedding(&self_.0.graph)?;
- self_
- .0
- .expand_edges_by_similarity(&embedding, limit, translate_window(window))?;
+ let embedding = query.into_embedding(&slf.0.graph)?;
+ let w = translate_window(window);
+ block_on(slf.0.expand_edges_by_similarity(&embedding, limit, w))?;
Ok(())
}
}
-
-impl EmbeddingFunction for Py {
- fn call(&self, texts: Vec<String>) -> BoxFuture<'static, EmbeddingResult<Vec<Embedding>>> {
- let embedding_function = Python::with_gil(|py| self.clone_ref(py));
- Box::pin(async move {
- Python::with_gil(|py| {
- let embedding_function = embedding_function.bind(py);
- let python_texts = PyList::new(py, texts)?;
- let result = embedding_function.call1((python_texts,))?;
- let embeddings = result.downcast::<PyList>().map_err(|_| {
- PyTypeError::new_err(
- "value returned by the embedding function was not a python list",
- )
- })?;
-
- let embeddings: EmbeddingResult<Vec<Embedding>> = embeddings
- .iter()
- .map(|embedding| {
- let pylist = embedding.downcast::<PyList>().map_err(|_| {
- PyTypeError::new_err("one of the values in the list returned by the embedding function was not a python list")
- })?;
- let embedding: EmbeddingResult<Embedding> = pylist
- .iter()
- .map(|element| Ok(element.extract::<f32>()?))
- .collect();
- embedding
- })
- .collect();
- embeddings
- })
- })
- }
-}
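The removed batch-oriented `EmbeddingFunction` impl is replaced by the blocking per-text one above, fed by the new `embedding_server` machinery that exposes a Python function behind an OpenAI-style HTTP endpoint. How a client points at that endpoint is not wired up in this diff, but presumably via `OpenAIEmbeddings(api_base=...)`, as the TODO about making the OpenAI-like API transparent suggests. A hedged sketch (the address format and model name are assumptions):

```python
from raphtory.vectors import embedding_server, OpenAIEmbeddings

@embedding_server("localhost:8001")
def toy_embedding(text: str) -> list[float]:
    # deterministic toy embedding so the example is self-contained
    return [float(len(text)), 1.0, 0.0]

# run() blocks forever; start() returns a handle that is also a context manager
with toy_embedding.start():
    model = OpenAIEmbeddings(
        model="toy",
        api_base="http://localhost:8001",  # assumed to match the served address
    )
    # ... pass `model` to vectorise()/vectorise_graph() as in the examples above ...
```
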
diff --git a/raphtory/src/python/utils/mod.rs b/raphtory/src/python/utils/mod.rs
index 4ab0995b45..7660b62b08 100644
--- a/raphtory/src/python/utils/mod.rs
+++ b/raphtory/src/python/utils/mod.rs
@@ -411,7 +411,7 @@ impl<'py> IntoPyObject<'py> for NumpyArray {
// This function takes a function that returns a future instead of taking just a future because
// a task might return an unsendable future but what we can do is making a function returning that
// future which is sendable itself
-pub fn execute_async_task<T, F, O>(task: T) -> O
+pub(crate) fn execute_async_task<T, F, O>(task: T) -> O
where
T: FnOnce() -> F + Send + 'static,
F: Future