Skip to content

Commit 4a37035

Browse files
committed
async_to_sync wrapper
1 parent b0ee1e4 commit 4a37035

File tree

4 files changed

+18
-16
lines changed

4 files changed

+18
-16
lines changed

src/neo4j_graphrag/experimental/components/resolver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from neo4j_graphrag.experimental.components.types import ResolutionStats
2121
from neo4j_graphrag.experimental.pipeline import Component
22-
from neo4j_graphrag.utils import execute_query
22+
from neo4j_graphrag.utils import execute_query, async_to_sync
2323

2424

2525
class EntityResolver(Component, abc.ABC):
@@ -140,3 +140,5 @@ async def run(self) -> ResolutionStats:
140140
number_of_nodes_to_resolve=number_of_nodes_to_resolve,
141141
number_of_created_nodes=number_of_created_nodes,
142142
)
143+
144+
run_sync = async_to_sync(run)

src/neo4j_graphrag/experimental/pipeline/component.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pydantic import BaseModel
2222

2323
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
24+
from neo4j_graphrag.utils import async_to_sync
2425

2526

2627
class DataModel(BaseModel):
@@ -63,6 +64,8 @@ def __new__(
6364
}
6465
for f, field in return_model.model_fields.items()
6566
}
67+
# create sync method:
68+
attrs["run_sync"] = async_to_sync(run_method)
6669
return type.__new__(meta, name, bases, attrs)
6770

6871

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from timeit import default_timer
2525
from typing import Any, AsyncGenerator, Optional
2626

27+
from neo4j_graphrag.utils import async_to_sync
28+
2729
try:
2830
import pygraphviz as pgv
2931
except ImportError:
@@ -105,6 +107,8 @@ async def run(self, inputs: dict[str, Any]) -> RunResult | None:
105107
logger.debug(f"TASK RESULT {self.name=} {res=}")
106108
return res
107109

110+
run_sync = async_to_sync(run)
111+
108112

109113
class Orchestrator:
110114
"""Orchestrate a pipeline.
@@ -618,3 +622,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
618622
run_id=orchestrator.run_id,
619623
result=await self.final_results.get(orchestrator.run_id),
620624
)
625+
626+
run_sync = async_to_sync(run)

src/neo4j_graphrag/utils.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import inspect
18+
from functools import wraps
1819
from typing import Any, Optional, Union
1920
import asyncio
2021
import concurrent.futures
@@ -48,18 +49,8 @@ def run_sync(function, *args, **kwargs):
4849
return return_value
4950

5051

51-
52-
if __name__ == "__main__":
53-
async def async_run(char: str, repeat: int = 2) -> str:
54-
await asyncio.sleep(5)
55-
return char * repeat
56-
57-
async def async_run_multiple(char, n=10):
58-
return await asyncio.gather(*[
59-
async_run(char)
60-
for _ in range(n)
61-
])
62-
63-
print(
64-
run_sync(async_run_multiple, "abc")
65-
)
52+
def async_to_sync(func):
53+
@wraps(func)
54+
def wrapper(*args, **kwargs):
55+
return run_sync(func, *args, **kwargs)
56+
return wrapper

0 commit comments

Comments
 (0)