Skip to content

Commit 9213ea9

Browse files
authored
refactor(python-sdk): make async APIs clear (#465)
* refactor(python-sdk): make async APIs clear * refactor: reuse sync method
1 parent 6fb0f65 commit 9213ea9

File tree

4 files changed

+48
-31
lines changed

4 files changed

+48
-31
lines changed

examples/gdrive_text_embedding/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def gdrive_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope:
5353
default_similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)
5454

5555
@cocoindex.main_fn()
56-
async def _run():
56+
def _run():
5757
# Use a `FlowLiveUpdater` to keep the flow data updated.
58-
async with cocoindex.FlowLiveUpdater(gdrive_text_embedding_flow):
58+
with cocoindex.FlowLiveUpdater(gdrive_text_embedding_flow):
5959
# Run queries in a loop to demonstrate the query capabilities.
6060
while True:
6161
try:
@@ -74,4 +74,4 @@ async def _run():
7474

7575
if __name__ == "__main__":
7676
load_dotenv(override=True)
77-
asyncio.run(_run())
77+
_run()

python/cocoindex/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from . import functions, query, sources, storages, cli
55
from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def
66
from .flow import EvaluateAndDumpOptions, GeneratedField
7-
from .flow import update_all_flows, FlowLiveUpdater, FlowLiveUpdaterOptions
7+
from .flow import update_all_flows_async, FlowLiveUpdater, FlowLiveUpdaterOptions
88
from .llm import LlmSpec, LlmApiType
99
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
1010
from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry

python/cocoindex/cli.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import click
32
import datetime
43

@@ -7,7 +6,6 @@
76

87
from . import flow, lib, setting
98
from .setup import sync_setup, drop_setup, flow_names_with_setup, apply_setup_changes
10-
from .runtime import execution_context
119

1210
@click.group()
1311
def cli():
@@ -136,13 +134,12 @@ def update(flow_name: str | None, live: bool, quiet: bool):
136134
Update the index to reflect the latest data from data sources.
137135
"""
138136
options = flow.FlowLiveUpdaterOptions(live_mode=live, print_stats=not quiet)
139-
async def _update():
140-
if flow_name is None:
141-
await flow.update_all_flows(options)
142-
else:
143-
updater = await flow.FlowLiveUpdater.create(_flow_by_name(flow_name), options)
144-
await updater.wait()
145-
execution_context.run(_update())
137+
if flow_name is None:
138+
return flow.update_all_flows(options)
139+
else:
140+
updater = flow.FlowLiveUpdater(_flow_by_name(flow_name), options)
141+
updater.wait()
142+
return updater.update_stats()
146143

147144
@cli.command()
148145
@click.argument("flow_name", type=str, required=False)
@@ -217,7 +214,7 @@ def server(address: str | None, live_update: bool, quiet: bool, cors_origin: str
217214

218215
if live_update:
219216
options = flow.FlowLiveUpdaterOptions(live_mode=True, print_stats=not quiet)
220-
execution_context.run(flow.update_all_flows(options))
217+
flow.update_all_flows(options)
221218
if COCOINDEX_HOST in cors_origins:
222219
click.echo(f"Open CocoInsight at: {COCOINDEX_HOST}/cocoinsight")
223220
input("Press Enter to stop...")

python/cocoindex/flow.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import re
99
import inspect
1010
import datetime
11-
import json
1211

1312
from typing import Any, Callable, Sequence, TypeVar
1413
from threading import Lock
@@ -394,12 +393,13 @@ def __init__(self, arg: Flow | _engine.FlowLiveUpdater, options: FlowLiveUpdater
394393
arg.internal_flow(), dump_engine_object(options or FlowLiveUpdaterOptions())))
395394

396395
@staticmethod
397-
async def create(fl: Flow, options: FlowLiveUpdaterOptions | None = None) -> FlowLiveUpdater:
396+
async def create_async(fl: Flow, options: FlowLiveUpdaterOptions | None = None) -> FlowLiveUpdater:
398397
"""
399398
Create a live updater for a flow.
399+
Similar to the constructor, but for async usage.
400400
"""
401401
engine_live_updater = await _engine.FlowLiveUpdater.create(
402-
await fl.ainternal_flow(),
402+
await fl.internal_flow_async(),
403403
dump_engine_object(options or FlowLiveUpdaterOptions()))
404404
return FlowLiveUpdater(engine_live_updater)
405405

@@ -408,21 +408,28 @@ def __enter__(self) -> FlowLiveUpdater:
408408

409409
def __exit__(self, exc_type, exc_value, traceback):
410410
self.abort()
411-
execution_context.run(self.wait())
411+
self.wait()
412412

413413
async def __aenter__(self) -> FlowLiveUpdater:
414414
return self
415415

416416
async def __aexit__(self, exc_type, exc_value, traceback):
417417
self.abort()
418-
await self.wait()
418+
await self.wait_async()
419419

420-
async def wait(self) -> None:
420+
def wait(self) -> None:
421421
"""
422422
Wait for the live updater to finish.
423423
"""
424+
execution_context.run(self.wait_async())
425+
426+
async def wait_async(self) -> None:
427+
"""
428+
Wait for the live updater to finish. Async version.
429+
"""
424430
await self._engine_live_updater.wait()
425431

432+
426433
def abort(self) -> None:
427434
"""
428435
Abort the live updater.
@@ -500,13 +507,20 @@ def name(self) -> str:
500507
"""
501508
return self._lazy_engine_flow().name()
502509

503-
async def update(self) -> _engine.IndexUpdateInfo:
510+
def update(self) -> _engine.IndexUpdateInfo:
511+
"""
512+
Update the index defined by the flow.
513+
Once the function returns, the index is fresh up to the moment when the function is called.
514+
"""
515+
return execution_context.run(self.update_async())
516+
517+
async def update_async(self) -> _engine.IndexUpdateInfo:
504518
"""
505519
Update the index defined by the flow.
506-
Once the function returns, the indice is fresh up to the moment when the function is called.
520+
Once the function returns, the index is fresh up to the moment when the function is called.
507521
"""
508-
updater = await FlowLiveUpdater.create(self, FlowLiveUpdaterOptions(live_mode=False))
509-
await updater.wait()
522+
updater = await FlowLiveUpdater.create_async(self, FlowLiveUpdaterOptions(live_mode=False))
523+
await updater.wait_async()
510524
return updater.update_stats()
511525

512526
def evaluate_and_dump(self, options: EvaluateAndDumpOptions):
@@ -521,7 +535,7 @@ def internal_flow(self) -> _engine.Flow:
521535
"""
522536
return self._lazy_engine_flow()
523537

524-
async def ainternal_flow(self) -> _engine.Flow:
538+
async def internal_flow_async(self) -> _engine.Flow:
525539
"""
526540
Get the engine flow. The async version.
527541
"""
@@ -587,21 +601,27 @@ def ensure_all_flows_built() -> None:
587601
for fl in flows():
588602
fl.internal_flow()
589603

590-
async def aensure_all_flows_built() -> None:
604+
async def ensure_all_flows_built_async() -> None:
591605
"""
592606
Ensure all flows are built.
593607
"""
594608
for fl in flows():
595-
await fl.ainternal_flow()
609+
await fl.internal_flow_async()
610+
611+
def update_all_flows(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]:
612+
"""
613+
Update all flows.
614+
"""
615+
return execution_context.run(update_all_flows_async(options))
596616

597-
async def update_all_flows(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]:
617+
async def update_all_flows_async(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]:
598618
"""
599619
Update all flows.
600620
"""
601-
await aensure_all_flows_built()
621+
await ensure_all_flows_built_async()
602622
async def _update_flow(fl: Flow) -> _engine.IndexUpdateInfo:
603-
updater = await FlowLiveUpdater.create(fl, options)
604-
await updater.wait()
623+
updater = await FlowLiveUpdater.create_async(fl, options)
624+
await updater.wait_async()
605625
return updater.update_stats()
606626
fls = flows()
607627
all_stats = await asyncio.gather(*(_update_flow(fl) for fl in fls))

0 commit comments

Comments
 (0)