diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index 7a309b7ff..4171f7243 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -120,9 +120,10 @@ def _transform_helper( else: raise ValueError("transform() can only be called on a CocoIndex function") - return _create_data_slice( - flow_builder_state, - lambda target_scope, name: flow_builder_state.engine_flow_builder.transform( + def _create_data_slice_inner( + target_scope: _engine.DataScopeRef | None, name: str | None + ) -> _engine.DataSlice: + result = flow_builder_state.engine_flow_builder.transform( kind, dump_engine_object(spec), transform_args, @@ -130,7 +131,12 @@ def _transform_helper( flow_builder_state.field_name_builder.build_name( name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_" ), - ), + ) + return result + + return _create_data_slice( + flow_builder_state, + _create_data_slice_inner, name, ) @@ -166,6 +172,7 @@ def __init__( def engine_data_slice(self) -> _engine.DataSlice: """ Get the internal DataSlice. + This can be blocking. """ if self._lazy_lock is None: if self._data_slice is None: @@ -179,6 +186,13 @@ def engine_data_slice(self) -> _engine.DataSlice: self._data_slice = self._data_slice_creator(None) return self._data_slice + async def engine_data_slice_async(self) -> _engine.DataSlice: + """ + Get the internal DataSlice. + This can be blocking. + """ + return await asyncio.to_thread(lambda: self.engine_data_slice) + def attach_to_scope(self, scope: _engine.DataScopeRef, field_name: str) -> None: """ Attach the current data slice (if not yet attached) to the given scope. @@ -795,9 +809,8 @@ async def setup_async(self, report_to_stdout: bool = False) -> None: """ Setup persistent backends of the flow. The async version. """ - await make_setup_bundle([self]).describe_and_apply_async( - report_to_stdout=report_to_stdout - ) + bundle = await make_setup_bundle_async([self]) + await bundle.describe_and_apply_async(report_to_stdout=report_to_stdout) def drop(self, report_to_stdout: bool = False) -> None: """ @@ -814,9 +827,8 @@ async def drop_async(self, report_to_stdout: bool = False) -> None: """ Drop persistent backends of the flow. The async version. """ - await make_drop_bundle([self]).describe_and_apply_async( - report_to_stdout=report_to_stdout - ) + bundle = await make_drop_bundle_async([self]) + await bundle.describe_and_apply_async(report_to_stdout=report_to_stdout) def close(self) -> None: """ @@ -1071,19 +1083,16 @@ async def _build_flow_info_async(self) -> TransformFlowInfo: _DataSliceState(flow_builder_state, engine_ds) ) - output = self._flow_fn(**kwargs) - flow_builder_state.engine_flow_builder.set_direct_output( - _data_slice_state(output).engine_data_slice - ) + output = await asyncio.to_thread(lambda: self._flow_fn(**kwargs)) + output_data_slice = await _data_slice_state(output).engine_data_slice_async() + + flow_builder_state.engine_flow_builder.set_direct_output(output_data_slice) engine_flow = ( await flow_builder_state.engine_flow_builder.build_transient_flow_async( execution_context.event_loop ) ) - - engine_return_type = ( - _data_slice_state(output).engine_data_slice.data_type().schema() - ) + engine_return_type = output_data_slice.data_type().schema() python_return_type: type[T] | None = _get_data_slice_annotation_type( inspect.signature(self._flow_fn).return_annotation ) @@ -1142,28 +1151,42 @@ def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]) -> TransformFlow[T] return _transform_flow_wrapper -def make_setup_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle: +async def make_setup_bundle_async(flow_iter: Iterable[Flow]) -> SetupChangeBundle: """ Make a bundle to setup flows with the given names. """ full_names = [] for fl in flow_iter: - fl.internal_flow() + await fl.internal_flow_async() full_names.append(fl.full_name) return SetupChangeBundle(_engine.make_setup_bundle(full_names)) -def make_drop_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle: +def make_setup_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle: + """ + Make a bundle to setup flows with the given names. + """ + return execution_context.run(make_setup_bundle_async(flow_iter)) + + +async def make_drop_bundle_async(flow_iter: Iterable[Flow]) -> SetupChangeBundle: """ Make a bundle to drop flows with the given names. """ full_names = [] for fl in flow_iter: - fl.internal_flow() + await fl.internal_flow_async() full_names.append(fl.full_name) return SetupChangeBundle(_engine.make_drop_bundle(full_names)) +def make_drop_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle: + """ + Make a bundle to drop flows with the given names. + """ + return execution_context.run(make_drop_bundle_async(flow_iter)) + + def setup_all_flows(report_to_stdout: bool = False) -> None: """ Setup all flows registered in the current process. diff --git a/python/cocoindex/tests/test_transform_flow.py b/python/cocoindex/tests/test_transform_flow.py index 38b6a3a47..2d9d274b0 100644 --- a/python/cocoindex/tests/test_transform_flow.py +++ b/python/cocoindex/tests/test_transform_flow.py @@ -166,6 +166,25 @@ def __call__(self, text: str) -> str: return f"{text}{self.spec.suffix}" +class GpuAppendSuffixWithAnalyzePrepare(cocoindex.op.FunctionSpec): + suffix: str + + +@cocoindex.op.executor_class(gpu=True) +class GpuAppendSuffixWithAnalyzePrepareExecutor: + spec: GpuAppendSuffixWithAnalyzePrepare + suffix: str + + def analyze(self) -> Any: + return str + + def prepare(self) -> None: + self.suffix = self.spec.suffix + + def __call__(self, text: str) -> str: + return f"{text}{self.suffix}" + + def test_gpu_function() -> None: @cocoindex.transform_flow() def transform_flow(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]: @@ -174,3 +193,15 @@ def transform_flow(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]: result = transform_flow.eval("Hello") expected = "Hello world!" assert result == expected, f"Expected {expected}, got {result}" + + @cocoindex.transform_flow() + def transform_flow_with_analyze_prepare( + text: cocoindex.DataSlice[str], + ) -> cocoindex.DataSlice[str]: + return text.transform(gpu_append_world).transform( + GpuAppendSuffixWithAnalyzePrepare(suffix="!!") + ) + + result = transform_flow_with_analyze_prepare.eval("Hello") + expected = "Hello world!!" + assert result == expected, f"Expected {expected}, got {result}"