Skip to content

Commit 61d1ed6

Browse files
authored
refactor(sdk): make FlowLiveUpdater constructor like with a start() (#466)
1 parent 9213ea9 commit 61d1ed6

File tree

2 files changed

+35
-28
lines changed

2 files changed

+35
-28
lines changed

python/cocoindex/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,9 @@ def update(flow_name: str | None, live: bool, quiet: bool):
137137
if flow_name is None:
138138
return flow.update_all_flows(options)
139139
else:
140-
updater = flow.FlowLiveUpdater(_flow_by_name(flow_name), options)
141-
updater.wait()
142-
return updater.update_stats()
140+
with flow.FlowLiveUpdater(_flow_by_name(flow_name), options) as updater:
141+
updater.wait()
142+
return updater.update_stats()
143143

144144
@cli.command()
145145
@click.argument("flow_name", type=str, required=False)

python/cocoindex/flow.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -383,40 +383,43 @@ class FlowLiveUpdater:
383383
"""
384384
A live updater for a flow.
385385
"""
386-
_engine_live_updater: _engine.FlowLiveUpdater
386+
_flow: Flow
387+
_options: FlowLiveUpdaterOptions
388+
_engine_live_updater: _engine.FlowLiveUpdater | None = None
387389

388-
def __init__(self, arg: Flow | _engine.FlowLiveUpdater, options: FlowLiveUpdaterOptions | None = None):
389-
if isinstance(arg, _engine.FlowLiveUpdater):
390-
self._engine_live_updater = arg
391-
else:
392-
self._engine_live_updater = execution_context.run(_engine.FlowLiveUpdater(
393-
arg.internal_flow(), dump_engine_object(options or FlowLiveUpdaterOptions())))
394-
395-
@staticmethod
396-
async def create_async(fl: Flow, options: FlowLiveUpdaterOptions | None = None) -> FlowLiveUpdater:
397-
"""
398-
Create a live updater for a flow.
399-
Similar to the constructor, but for async usage.
400-
"""
401-
engine_live_updater = await _engine.FlowLiveUpdater.create(
402-
await fl.internal_flow_async(),
403-
dump_engine_object(options or FlowLiveUpdaterOptions()))
404-
return FlowLiveUpdater(engine_live_updater)
390+
def __init__(self, fl: Flow, options: FlowLiveUpdaterOptions | None = None):
391+
self._flow = fl
392+
self._options = options or FlowLiveUpdaterOptions()
405393

406394
def __enter__(self) -> FlowLiveUpdater:
395+
self.start()
407396
return self
408397

409398
def __exit__(self, exc_type, exc_value, traceback):
410399
self.abort()
411400
self.wait()
412401

413402
async def __aenter__(self) -> FlowLiveUpdater:
403+
await self.start_async()
414404
return self
415405

416406
async def __aexit__(self, exc_type, exc_value, traceback):
417407
self.abort()
418408
await self.wait_async()
419409

410+
def start(self) -> None:
411+
"""
412+
Start the live updater.
413+
"""
414+
execution_context.run(self.start_async())
415+
416+
async def start_async(self) -> None:
417+
"""
418+
Start the live updater.
419+
"""
420+
self._engine_live_updater = await _engine.FlowLiveUpdater.create(
421+
await self._flow.internal_flow_async(), dump_engine_object(self._options))
422+
420423
def wait(self) -> None:
421424
"""
422425
Wait for the live updater to finish.
@@ -427,20 +430,24 @@ async def wait_async(self) -> None:
427430
"""
428431
Wait for the live updater to finish. Async version.
429432
"""
430-
await self._engine_live_updater.wait()
431-
433+
await self._get_engine_live_updater().wait()
432434

433435
def abort(self) -> None:
434436
"""
435437
Abort the live updater.
436438
"""
437-
self._engine_live_updater.abort()
439+
self._get_engine_live_updater().abort()
438440

439441
def update_stats(self) -> _engine.IndexUpdateInfo:
440442
"""
441443
Get the index update info.
442444
"""
443-
return self._engine_live_updater.index_update_info()
445+
return self._get_engine_live_updater().index_update_info()
446+
447+
def _get_engine_live_updater(self) -> _engine.FlowLiveUpdater:
448+
if self._engine_live_updater is None:
449+
raise RuntimeError("Live updater is not started")
450+
return self._engine_live_updater
444451

445452

446453
@dataclass
@@ -620,9 +627,9 @@ async def update_all_flows_async(options: FlowLiveUpdaterOptions) -> dict[str, _
620627
"""
621628
await ensure_all_flows_built_async()
622629
async def _update_flow(fl: Flow) -> _engine.IndexUpdateInfo:
623-
updater = await FlowLiveUpdater.create_async(fl, options)
624-
await updater.wait_async()
625-
return updater.update_stats()
630+
async with FlowLiveUpdater(fl, options) as updater:
631+
await updater.wait_async()
632+
return updater.update_stats()
626633
fls = flows()
627634
all_stats = await asyncio.gather(*(_update_flow(fl) for fl in fls))
628635
return {fl.name: stats for fl, stats in zip(fls, all_stats)}

0 commit comments

Comments
 (0)