2222)
2323from datachain .query .queue import get_from_queue , put_into_queue
2424from datachain .query .udf import UdfInfo
25- from datachain .query .utils import get_query_id_column
2625from datachain .utils import batched , flatten , safe_closing
2726
2827if TYPE_CHECKING :
@@ -55,6 +54,9 @@ def udf_entrypoint() -> int:
5554 udf_info : UdfInfo = load (stdin .buffer )
5655
5756 query = udf_info ["query" ]
57+ if "sys__id" not in query .selected_columns :
58+ raise RuntimeError ("sys__id column is required in UDF query" )
59+
5860 batching = udf_info ["batching" ]
5961 is_generator = udf_info ["is_generator" ]
6062
@@ -65,15 +67,16 @@ def udf_entrypoint() -> int:
6567 wh_cls , wh_args , wh_kwargs = udf_info ["warehouse_clone_params" ]
6668 warehouse : AbstractWarehouse = wh_cls (* wh_args , ** wh_kwargs )
6769
68- id_col = get_query_id_column (query )
69-
7070 with contextlib .closing (
71- batching (warehouse .dataset_select_paginated , query , id_col = id_col )
71+ batching (
72+ warehouse .dataset_select_paginated ,
73+ query ,
74+ id_col = query .selected_columns .sys__id ,
75+ )
7276 ) as udf_inputs :
7377 try :
7478 UDFDispatcher (udf_info ).run_udf (
7579 udf_inputs ,
76- ids_only = id_col is not None ,
7780 download_cb = download_cb ,
7881 processed_cb = processed_cb ,
7982 generated_cb = generated_cb ,
@@ -147,10 +150,10 @@ def _create_worker(self) -> "UDFWorker":
147150 self .udf_fields ,
148151 )
149152
150- def _run_worker (self , ids_only : bool ) -> None :
153+ def _run_worker (self ) -> None :
151154 try :
152155 worker = self ._create_worker ()
153- worker .run (ids_only )
156+ worker .run ()
154157 except (Exception , KeyboardInterrupt ) as e :
155158 if self .done_queue :
156159 put_into_queue (
@@ -164,7 +167,6 @@ def _run_worker(self, ids_only: bool) -> None:
164167 def run_udf (
165168 self ,
166169 input_rows : Iterable ["RowsOutput" ],
167- ids_only : bool ,
168170 download_cb : Callback = DEFAULT_CALLBACK ,
169171 processed_cb : Callback = DEFAULT_CALLBACK ,
170172 generated_cb : Callback = DEFAULT_CALLBACK ,
@@ -178,9 +180,7 @@ def run_udf(
178180
179181 if n_workers == 1 :
180182 # no need to spawn worker processes if we are running in a single process
181- self .run_udf_single (
182- input_rows , ids_only , download_cb , processed_cb , generated_cb
183- )
183+ self .run_udf_single (input_rows , download_cb , processed_cb , generated_cb )
184184 else :
185185 if self .buffer_size < n_workers :
186186 raise RuntimeError (
@@ -189,13 +189,12 @@ def run_udf(
189189 )
190190
191191 self .run_udf_parallel (
192- n_workers , input_rows , ids_only , download_cb , processed_cb , generated_cb
192+ n_workers , input_rows , download_cb , processed_cb , generated_cb
193193 )
194194
195195 def run_udf_single (
196196 self ,
197197 input_rows : Iterable ["RowsOutput" ],
198- ids_only : bool ,
199198 download_cb : Callback = DEFAULT_CALLBACK ,
200199 processed_cb : Callback = DEFAULT_CALLBACK ,
201200 generated_cb : Callback = DEFAULT_CALLBACK ,
@@ -204,18 +203,15 @@ def run_udf_single(
204203 # Rebuild schemas in single process too for consistency (cheap, idempotent).
205204 ModelStore .rebuild_all ()
206205
207- if ids_only and not self .is_batching :
206+ if not self .is_batching :
208207 input_rows = flatten (input_rows )
209208
210209 def get_inputs () -> Iterable ["RowsOutput" ]:
211210 warehouse = self .catalog .warehouse .clone ()
212- if ids_only :
213- for ids in batched (input_rows , DEFAULT_BATCH_SIZE ):
214- yield from warehouse .dataset_rows_select_from_ids (
215- self .query , ids , self .is_batching
216- )
217- else :
218- yield from input_rows
211+ for ids in batched (input_rows , DEFAULT_BATCH_SIZE ):
212+ yield from warehouse .dataset_rows_select_from_ids (
213+ self .query , ids , self .is_batching
214+ )
219215
220216 prefetch = udf .prefetch
221217 with _get_cache (self .catalog .cache , prefetch , use_cache = self .cache ) as _cache :
@@ -249,7 +245,6 @@ def run_udf_parallel( # noqa: C901, PLR0912
249245 self ,
250246 n_workers : int ,
251247 input_rows : Iterable ["RowsOutput" ],
252- ids_only : bool ,
253248 download_cb : Callback = DEFAULT_CALLBACK ,
254249 processed_cb : Callback = DEFAULT_CALLBACK ,
255250 generated_cb : Callback = DEFAULT_CALLBACK ,
@@ -258,9 +253,7 @@ def run_udf_parallel( # noqa: C901, PLR0912
258253 self .done_queue = self .ctx .Queue ()
259254
260255 pool = [
261- self .ctx .Process (
262- name = f"Worker-UDF-{ i } " , target = self ._run_worker , args = [ids_only ]
263- )
256+ self .ctx .Process (name = f"Worker-UDF-{ i } " , target = self ._run_worker )
264257 for i in range (n_workers )
265258 ]
266259 for p in pool :
@@ -406,13 +399,13 @@ def __init__(
406399 self .processed_cb = ProcessedCallback ("processed" , self .done_queue )
407400 self .generated_cb = ProcessedCallback ("generated" , self .done_queue )
408401
409- def run (self , ids_only : bool ) -> None :
402+ def run (self ) -> None :
410403 prefetch = self .udf .prefetch
411404 with _get_cache (self .catalog .cache , prefetch , use_cache = self .cache ) as _cache :
412405 catalog = clone_catalog_with_cache (self .catalog , _cache )
413406 udf_results = self .udf .run (
414407 self .udf_fields ,
415- self .get_inputs (ids_only ),
408+ self .get_inputs (),
416409 catalog ,
417410 self .cache ,
418411 download_cb = self .download_cb ,
@@ -434,13 +427,10 @@ def notify_and_process(self, udf_results):
434427 put_into_queue (self .done_queue , {"status" : OK_STATUS })
435428 yield row
436429
437- def get_inputs (self , ids_only : bool ) -> Iterable ["RowsOutput" ]:
430+ def get_inputs (self ) -> Iterable ["RowsOutput" ]:
438431 warehouse = self .catalog .warehouse .clone ()
439432 while (batch := get_from_queue (self .task_queue )) != STOP_SIGNAL :
440- if ids_only :
441- for ids in batched (batch , DEFAULT_BATCH_SIZE ):
442- yield from warehouse .dataset_rows_select_from_ids (
443- self .query , ids , self .is_batching
444- )
445- else :
446- yield from batch
433+ for ids in batched (batch , DEFAULT_BATCH_SIZE ):
434+ yield from warehouse .dataset_rows_select_from_ids (
435+ self .query , ids , self .is_batching
436+ )
0 commit comments