Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,23 @@ def _transform_helper(
else:
raise ValueError("transform() can only be called on a CocoIndex function")

return _create_data_slice(
flow_builder_state,
lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
def _create_data_slice_inner(
target_scope: _engine.DataScopeRef | None, name: str | None
) -> _engine.DataSlice:
result = flow_builder_state.engine_flow_builder.transform(
kind,
dump_engine_object(spec),
transform_args,
target_scope,
flow_builder_state.field_name_builder.build_name(
name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_"
),
),
)
return result

return _create_data_slice(
flow_builder_state,
_create_data_slice_inner,
name,
)

Expand Down Expand Up @@ -166,6 +172,7 @@ def __init__(
def engine_data_slice(self) -> _engine.DataSlice:
"""
Get the internal DataSlice.
This can be blocking.
"""
if self._lazy_lock is None:
if self._data_slice is None:
Expand All @@ -179,6 +186,13 @@ def engine_data_slice(self) -> _engine.DataSlice:
self._data_slice = self._data_slice_creator(None)
return self._data_slice

async def engine_data_slice_async(self) -> _engine.DataSlice:
"""
Get the internal DataSlice.
This can be blocking.
"""
return await asyncio.to_thread(lambda: self.engine_data_slice)

def attach_to_scope(self, scope: _engine.DataScopeRef, field_name: str) -> None:
"""
Attach the current data slice (if not yet attached) to the given scope.
Expand Down Expand Up @@ -795,9 +809,8 @@ async def setup_async(self, report_to_stdout: bool = False) -> None:
"""
Setup persistent backends of the flow. The async version.
"""
await make_setup_bundle([self]).describe_and_apply_async(
report_to_stdout=report_to_stdout
)
bundle = await make_setup_bundle_async([self])
await bundle.describe_and_apply_async(report_to_stdout=report_to_stdout)

def drop(self, report_to_stdout: bool = False) -> None:
"""
Expand All @@ -814,9 +827,8 @@ async def drop_async(self, report_to_stdout: bool = False) -> None:
"""
Drop persistent backends of the flow. The async version.
"""
await make_drop_bundle([self]).describe_and_apply_async(
report_to_stdout=report_to_stdout
)
bundle = await make_drop_bundle_async([self])
await bundle.describe_and_apply_async(report_to_stdout=report_to_stdout)

def close(self) -> None:
"""
Expand Down Expand Up @@ -1071,19 +1083,16 @@ async def _build_flow_info_async(self) -> TransformFlowInfo:
_DataSliceState(flow_builder_state, engine_ds)
)

output = self._flow_fn(**kwargs)
flow_builder_state.engine_flow_builder.set_direct_output(
_data_slice_state(output).engine_data_slice
)
output = await asyncio.to_thread(lambda: self._flow_fn(**kwargs))
output_data_slice = await _data_slice_state(output).engine_data_slice_async()

flow_builder_state.engine_flow_builder.set_direct_output(output_data_slice)
engine_flow = (
await flow_builder_state.engine_flow_builder.build_transient_flow_async(
execution_context.event_loop
)
)

engine_return_type = (
_data_slice_state(output).engine_data_slice.data_type().schema()
)
engine_return_type = output_data_slice.data_type().schema()
python_return_type: type[T] | None = _get_data_slice_annotation_type(
inspect.signature(self._flow_fn).return_annotation
)
Expand Down Expand Up @@ -1142,28 +1151,42 @@ def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]) -> TransformFlow[T]
return _transform_flow_wrapper


def make_setup_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
async def make_setup_bundle_async(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
"""
Make a bundle to setup flows with the given names.
"""
full_names = []
for fl in flow_iter:
fl.internal_flow()
await fl.internal_flow_async()
full_names.append(fl.full_name)
return SetupChangeBundle(_engine.make_setup_bundle(full_names))


def make_drop_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
def make_setup_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
"""
Make a bundle to setup flows with the given names.
"""
return execution_context.run(make_setup_bundle_async(flow_iter))


async def make_drop_bundle_async(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
"""
Make a bundle to drop flows with the given names.
"""
full_names = []
for fl in flow_iter:
fl.internal_flow()
await fl.internal_flow_async()
full_names.append(fl.full_name)
return SetupChangeBundle(_engine.make_drop_bundle(full_names))


def make_drop_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
"""
Make a bundle to drop flows with the given names.
"""
return execution_context.run(make_drop_bundle_async(flow_iter))


def setup_all_flows(report_to_stdout: bool = False) -> None:
"""
Setup all flows registered in the current process.
Expand Down
31 changes: 31 additions & 0 deletions python/cocoindex/tests/test_transform_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,25 @@ def __call__(self, text: str) -> str:
return f"{text}{self.spec.suffix}"


class GpuAppendSuffixWithAnalyzePrepare(cocoindex.op.FunctionSpec):
suffix: str


@cocoindex.op.executor_class(gpu=True)
class GpuAppendSuffixWithAnalyzePrepareExecutor:
spec: GpuAppendSuffixWithAnalyzePrepare
suffix: str

def analyze(self) -> Any:
return str

def prepare(self) -> None:
self.suffix = self.spec.suffix

def __call__(self, text: str) -> str:
return f"{text}{self.suffix}"


def test_gpu_function() -> None:
@cocoindex.transform_flow()
def transform_flow(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]:
Expand All @@ -174,3 +193,15 @@ def transform_flow(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]:
result = transform_flow.eval("Hello")
expected = "Hello world!"
assert result == expected, f"Expected {expected}, got {result}"

@cocoindex.transform_flow()
def transform_flow_with_analyze_prepare(
text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[str]:
return text.transform(gpu_append_world).transform(
GpuAppendSuffixWithAnalyzePrepare(suffix="!!")
)

result = transform_flow_with_analyze_prepare.eval("Hello")
expected = "Hello world!!"
assert result == expected, f"Expected {expected}, got {result}"
Loading