From 67b22adbce825fb280190d0ead5cf11d811355e3 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 15 Oct 2024 20:28:09 +0200 Subject: [PATCH 1/4] Implement a sync run method - experiment 1 --- examples/pipeline/kg_builder_example.py | 0 .../experimental/pipeline/kg_builder.py | 5 ++++ src/neo4j_graphrag/utils.py | 25 +++++++++++++++++++ 3 files changed, 30 insertions(+) create mode 100644 examples/pipeline/kg_builder_example.py diff --git a/examples/pipeline/kg_builder_example.py b/examples/pipeline/kg_builder_example.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 3fca0215a..e459fb25d 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -34,6 +34,7 @@ ) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.utils import run_sync class SimpleKGPipeline: @@ -124,3 +125,7 @@ async def run_async( PipelineResult: The result of the pipeline execution. """ return await self.runner.run({"file_path": file_path, "text": text}) + + def run(self, file_path: Optional[str] = None, text: Optional[str] = None) -> PipelineResult: + """Run pipeline synchronously""" + return run_sync(self, file_path=file_path, text=text) diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index e86f7588a..60e4130d3 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -15,6 +15,8 @@ from __future__ import annotations from typing import Optional +import asyncio +import concurrent.futures def validate_search_query_input( @@ -22,3 +24,26 @@ def validate_search_query_input( ) -> None: if not (bool(query_vector) ^ bool(query_text)): raise ValueError("You must provide exactly one of query_vector or query_text.") + + +def run_sync(function, *args, **kwargs): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(lambda: asyncio.run(function(*args, **kwargs))) + return_value = future.result() + return return_value + + +if __name__ == "__main__": + async def async_run(char: str, repeat: int = 2) -> str: + await asyncio.sleep(5) + return char * repeat + + async def async_run_multiple(char, n=10): + return await asyncio.gather(*[ + async_run(char) + for _ in range(n) + ]) + + print( + run_sync(async_run_multiple, "abc") + ) From 1efbe8e7edfb28f28f940209db423b7c368e6090 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 16 Oct 2024 16:16:11 +0200 Subject: [PATCH 2/4] async_to_sync wrapper --- .../experimental/components/resolver.py | 3 +++ .../experimental/pipeline/component.py | 3 +++ .../experimental/pipeline/pipeline.py | 6 ++++++ src/neo4j_graphrag/utils.py | 20 ++++++------------- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index f2da0bff5..8d05d578b 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -19,6 +19,7 @@ from neo4j_graphrag.experimental.components.types import ResolutionStats from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.utils import async_to_sync class EntityResolver(Component, abc.ABC): @@ -136,3 +137,5 @@ async def run(self) -> ResolutionStats: number_of_nodes_to_resolve=number_of_nodes_to_resolve, number_of_created_nodes=number_of_created_nodes, ) + + run_sync = async_to_sync(run) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 84cd5bc0c..efbd9bcb9 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -21,6 +21,7 @@ from pydantic import BaseModel from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.utils import async_to_sync class DataModel(BaseModel): @@ -63,6 +64,8 @@ def __new__( } for f, field in return_model.model_fields.items() } + # create sync method: + attrs["run_sync"] = async_to_sync(run_method) return type.__new__(meta, name, bases, attrs) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index e3ded494d..3ab1d12e0 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -24,6 +24,8 @@ from timeit import default_timer from typing import Any, AsyncGenerator, Optional +from neo4j_graphrag.utils import async_to_sync + try: import pygraphviz as pgv except ImportError: @@ -107,6 +109,8 @@ async def run(self, inputs: dict[str, Any]) -> RunResult | None: logger.debug(f"TASK RESULT {self.name=} {res=}") return res + run_sync = async_to_sync(run) + class Orchestrator: """Orchestrate a pipeline. @@ -638,3 +642,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult: run_id=orchestrator.run_id, result=await self.final_results.get(orchestrator.run_id), ) + + run_sync = async_to_sync(run) diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index 60e4130d3..1962630fe 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import annotations +from functools import wraps from typing import Optional import asyncio import concurrent.futures @@ -33,17 +34,8 @@ def run_sync(function, *args, **kwargs): return return_value -if __name__ == "__main__": - async def async_run(char: str, repeat: int = 2) -> str: - await asyncio.sleep(5) - return char * repeat - - async def async_run_multiple(char, n=10): - return await asyncio.gather(*[ - async_run(char) - for _ in range(n) - ]) - - print( - run_sync(async_run_multiple, "abc") - ) +def async_to_sync(func): + @wraps(func) + def wrapper(*args, **kwargs): + return run_sync(func, *args, **kwargs) + return wrapper From 5d3b17e631c8176e9c214a630bdcd96978a933ce Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 2 Jan 2025 09:59:05 +0100 Subject: [PATCH 3/4] First argument must be the function + ruff --- src/neo4j_graphrag/experimental/pipeline/kg_builder.py | 6 ++++-- src/neo4j_graphrag/utils.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index e459fb25d..34d2a3fdd 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -126,6 +126,8 @@ async def run_async( """ return await self.runner.run({"file_path": file_path, "text": text}) - def run(self, file_path: Optional[str] = None, text: Optional[str] = None) -> PipelineResult: + def run( + self, file_path: Optional[str] = None, text: Optional[str] = None + ) -> PipelineResult: """Run pipeline synchronously""" - return run_sync(self, file_path=file_path, text=text) + return run_sync(self.run_async, file_path=file_path, text=text) diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index 1962630fe..20901d2c8 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -14,10 +14,10 @@ # limitations under the License. from __future__ import annotations -from functools import wraps -from typing import Optional import asyncio import concurrent.futures +from functools import wraps +from typing import Optional def validate_search_query_input( @@ -38,4 +38,5 @@ def async_to_sync(func): @wraps(func) def wrapper(*args, **kwargs): return run_sync(func, *args, **kwargs) + return wrapper From 0a6eb05a492ac094b0fa663c84c92ce28ddfb668 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 2 Jan 2025 10:56:36 +0100 Subject: [PATCH 4/4] Adds sync mode to runner and update example --- .../simple_kg_pipeline_from_config_file.py | 20 ++++++++++++++++--- .../experimental/pipeline/config/runner.py | 4 ++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py b/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py index 62ba6c85e..ee3d2f27f 100644 --- a/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py +++ b/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py @@ -22,6 +22,8 @@ logging.basicConfig() logging.getLogger("neo4j_graphrag").setLevel(logging.DEBUG) +ASYNC = False + os.environ["NEO4J_URI"] = "bolt://localhost:7687" os.environ["NEO4J_USER"] = "neo4j" os.environ["NEO4J_PASSWORD"] = "password" @@ -38,10 +40,22 @@ an aristocratic family that rules the planet Caladan, the rainy planet, since 10191.""" -async def main() -> PipelineResult: - pipeline = PipelineRunner.from_config_file(file_path) +def get_pipeline(file_path): + return PipelineRunner.from_config_file(file_path) + + +async def main(file_path) -> PipelineResult: + pipeline = get_pipeline(file_path) return await pipeline.run({"text": TEXT}) +def main_sync(file_path) -> PipelineResult: + pipeline = get_pipeline(file_path) + return pipeline.run_sync({"text": TEXT}) + + if __name__ == "__main__": - print(asyncio.run(main())) + if ASYNC: + print(asyncio.run(main(file_path))) + else: + print(main_sync(file_path)) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index a1a225858..eb3fe3564 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -48,6 +48,7 @@ from neo4j_graphrag.experimental.pipeline.config.types import PipelineType from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition +from neo4j_graphrag.utils import async_to_sync logger = logging.getLogger(__name__) @@ -130,3 +131,6 @@ async def close(self) -> None: logger.debug("PIPELINE_RUNNER: cleaning up (closing instantiated drivers...)") if self.config: await self.config.close() + + run_sync = async_to_sync(run) + close_sync = async_to_sync(close)