diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py b/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py index 62ba6c85e..ee3d2f27f 100644 --- a/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py +++ b/examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py @@ -22,6 +22,8 @@ logging.basicConfig() logging.getLogger("neo4j_graphrag").setLevel(logging.DEBUG) +ASYNC = False + os.environ["NEO4J_URI"] = "bolt://localhost:7687" os.environ["NEO4J_USER"] = "neo4j" os.environ["NEO4J_PASSWORD"] = "password" @@ -38,10 +40,22 @@ an aristocratic family that rules the planet Caladan, the rainy planet, since 10191.""" -async def main() -> PipelineResult: - pipeline = PipelineRunner.from_config_file(file_path) +def get_pipeline(file_path): + return PipelineRunner.from_config_file(file_path) + + +async def main(file_path) -> PipelineResult: + pipeline = get_pipeline(file_path) return await pipeline.run({"text": TEXT}) +def main_sync(file_path) -> PipelineResult: + pipeline = get_pipeline(file_path) + return pipeline.run_sync({"text": TEXT}) + + if __name__ == "__main__": - print(asyncio.run(main())) + if ASYNC: + print(asyncio.run(main(file_path))) + else: + print(main_sync(file_path)) diff --git a/examples/pipeline/kg_builder_example.py b/examples/pipeline/kg_builder_example.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index f2da0bff5..8d05d578b 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -19,6 +19,7 @@ from neo4j_graphrag.experimental.components.types import ResolutionStats from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.utils import async_to_sync class EntityResolver(Component, abc.ABC): @@ -136,3 +137,5 @@ async def run(self) -> ResolutionStats: number_of_nodes_to_resolve=number_of_nodes_to_resolve, number_of_created_nodes=number_of_created_nodes, ) + + run_sync = async_to_sync(run) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 84cd5bc0c..efbd9bcb9 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -21,6 +21,7 @@ from pydantic import BaseModel from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.utils import async_to_sync class DataModel(BaseModel): @@ -63,6 +64,8 @@ def __new__( } for f, field in return_model.model_fields.items() } + # create sync method: + attrs["run_sync"] = async_to_sync(run_method) return type.__new__(meta, name, bases, attrs) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index a1a225858..eb3fe3564 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -48,6 +48,7 @@ from neo4j_graphrag.experimental.pipeline.config.types import PipelineType from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition +from neo4j_graphrag.utils import async_to_sync logger = logging.getLogger(__name__) @@ -130,3 +131,6 @@ async def close(self) -> None: logger.debug("PIPELINE_RUNNER: cleaning up (closing instantiated drivers...)") if self.config: await self.config.close() + + run_sync = async_to_sync(run) + close_sync = async_to_sync(close) diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 3fca0215a..34d2a3fdd 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -34,6 +34,7 @@ ) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.utils import run_sync class SimpleKGPipeline: @@ -124,3 +125,9 @@ async def run_async( PipelineResult: The result of the pipeline execution. """ return await self.runner.run({"file_path": file_path, "text": text}) + + def run( + self, file_path: Optional[str] = None, text: Optional[str] = None + ) -> PipelineResult: + """Run pipeline synchronously""" + return run_sync(self.run_async, file_path=file_path, text=text) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index e3ded494d..3ab1d12e0 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -24,6 +24,8 @@ from timeit import default_timer from typing import Any, AsyncGenerator, Optional +from neo4j_graphrag.utils import async_to_sync + try: import pygraphviz as pgv except ImportError: @@ -107,6 +109,8 @@ async def run(self, inputs: dict[str, Any]) -> RunResult | None: logger.debug(f"TASK RESULT {self.name=} {res=}") return res + run_sync = async_to_sync(run) + class Orchestrator: """Orchestrate a pipeline. @@ -638,3 +642,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult: run_id=orchestrator.run_id, result=await self.final_results.get(orchestrator.run_id), ) + + run_sync = async_to_sync(run) diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index e86f7588a..20901d2c8 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -14,6 +14,9 @@ # limitations under the License. from __future__ import annotations +import asyncio +import concurrent.futures +from functools import wraps from typing import Optional @@ -22,3 +25,18 @@ def validate_search_query_input( ) -> None: if not (bool(query_vector) ^ bool(query_text)): raise ValueError("You must provide exactly one of query_vector or query_text.") + + +def run_sync(function, *args, **kwargs): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(lambda: asyncio.run(function(*args, **kwargs))) + return_value = future.result() + return return_value + + +def async_to_sync(func): + @wraps(func) + def wrapper(*args, **kwargs): + return run_sync(func, *args, **kwargs) + + return wrapper