11import base64
22import re
33from asyncio import Queue , TaskGroup
4+ from collections import defaultdict , deque
45from os .path import basename , splitext
56from time import perf_counter , time
6- from typing import Any , AsyncGenerator , Literal
7+ from typing import Any , AsyncGenerator , Literal , Sequence
78
89import numpy as np
910from async_lru import alru_cache
1314from pydantic import BaseModel
1415
1516from owl .configs import ENV_CONFIG
16- from owl .db .gen_table import GenerativeTableCore , KnowledgeTable
17+ from owl .db .gen_table import (
18+ ColumnMetadata ,
19+ GenerativeTableCore ,
20+ KnowledgeTable ,
21+ )
1722from owl .docparse import GeneralDocLoader
1823from owl .types import (
1924 AUDIO_FILE_EXTENSIONS ,
5661from owl .utils import mask_string , uuid7_draft2_str
5762from owl .utils .billing import BillingManager
5863from owl .utils .code import code_executor
64+ from owl .utils .concurrency import determine_concurrent_batches
5965from owl .utils .exceptions import (
6066 BadInputError ,
6167 JamaiException ,
@@ -98,6 +104,8 @@ def __init__(
98104 organization : OrganizationRead ,
99105 project : ProjectRead ,
100106 body : MultiRowAddRequest | MultiRowRegenRequest | RowAdd | RowRegen ,
107+ col_batch_size : int ,
108+ row_batch_size : int ,
101109 ) -> None :
102110 self .request = request
103111 self ._request_id : str = request .state .id
@@ -110,12 +118,12 @@ def __init__(
110118 raise ValueError (f"{ body .table_id = } but { table .table_id = } " )
111119 self .body = body
112120 self ._stream = self .body .stream
113- # Determine batch sizes
121+
114122 self ._multi_turn = (
115123 sum (getattr (col .gen_config , "multi_turn" , False ) for col in table .column_metadata ) > 0
116124 )
117- self ._col_batch_size = ENV_CONFIG . concurrent_cols_batch_size if body . concurrent else 1
118- self ._row_batch_size = 1 if self . _multi_turn else ENV_CONFIG . concurrent_rows_batch_size
125+ self ._col_batch_size = col_batch_size
126+ self ._row_batch_size = row_batch_size
119127
120128 @classmethod
121129 def _log (cls , msg : str , level : str = "INFO" , request_id : str = "" , ** kwargs ):
@@ -144,6 +152,104 @@ def _log_item(x: Any) -> str:
144152 else :
145153 return f"type={ type (x )} "
146154
155+ @staticmethod
156+ def _parse_prompt_dependencies (prompt : str | None ) -> list [str ]:
157+ if not prompt :
158+ return []
159+ return re .findall (GEN_CONFIG_VAR_PATTERN , prompt )
160+
161+ def _extract_upstream_columns (self , prompt : str | None ) -> list [str ]:
162+ return self ._parse_prompt_dependencies (prompt )
163+
164+ def _extract_all_upstream_columns (self , output_column_name : str ) -> list [str ]:
165+ return self ._extract_all_upstream_columns_from (
166+ self .table .column_metadata , output_column_name
167+ )
168+
169+ @staticmethod
170+ def _extract_all_upstream_columns_from (
171+ columns : Sequence [ColumnMetadata ], output_column_name : str
172+ ) -> list [str ]:
173+ try :
174+ idx = next (i for i , c in enumerate (columns ) if c .column_id == output_column_name )
175+ except StopIteration :
176+ return []
177+ return [
178+ c .column_id
179+ for c in columns [:idx ]
180+ if not (c .is_info_column or c .is_state_column or c .is_vector_column )
181+ ]
182+
@classmethod
def _collect_column_dependencies(
    cls,
    column: ColumnMetadata,
    *,
    columns: Sequence[ColumnMetadata],
    output_column_ids: set[str],
) -> list[str]:
    """Return the output columns that ``column`` depends on.

    The candidate dependencies are chosen by the type of ``column.gen_config``:

    - ``PythonGenConfig``: every upstream column returned by
      ``_extract_all_upstream_columns_from``.
    - ``CodeGenConfig`` / ``EmbedGenConfig``: the single ``source_column``.
    - ``LLMGenConfig``: the references parsed out of ``prompt``.
    - anything else (including no gen config): no dependencies.

    Args:
        column: Column whose dependencies are being resolved.
        columns: All columns of the table, in table order.
        output_column_ids: IDs of the output columns; candidates outside
            this set are discarded.

    Returns:
        The candidate dependencies that are members of ``output_column_ids``,
        in candidate order (duplicates are kept).
    """
    config = column.gen_config
    if config is None:
        return []

    candidates: list[str] = []
    if isinstance(config, PythonGenConfig):
        candidates = cls._extract_all_upstream_columns_from(columns, column.column_id)
    elif isinstance(config, (CodeGenConfig, EmbedGenConfig)):
        candidates = [config.source_column]
    elif isinstance(config, LLMGenConfig):
        candidates = cls._parse_prompt_dependencies(config.prompt)

    # Only dependencies on other output columns are kept.
    return [col_id for col_id in candidates if col_id in output_column_ids]
206+
@classmethod
def build_dependency_levels(cls, columns: Sequence[ColumnMetadata]) -> list[list[str]]:
    """Group output columns into dependency levels.

    Level 0 holds the output columns with no dependency on another output
    column. Each later level holds the columns whose remaining
    dependencies were all satisfied by the previous levels.

    Args:
        columns: All columns of the table, in table order.

    Returns:
        A list of levels, each a list of column IDs. Empty when there are
        no output columns.
    """
    outputs = [col for col in columns if col.is_output_column]
    if not outputs:
        return []

    output_ids = {col.column_id for col in outputs}
    # Number of not-yet-satisfied dependencies per output column.
    remaining = {col.column_id: 0 for col in outputs}
    # Maps a column ID to the IDs of the columns that depend on it.
    dependents: dict[str, list[str]] = {}

    for col in outputs:
        deps = cls._collect_column_dependencies(
            col,
            columns=columns,
            output_column_ids=output_ids,
        )
        for dep in deps:
            dependents.setdefault(dep, []).append(col.column_id)
            remaining[col.column_id] += 1

    levels: list[list[str]] = []
    frontier = [col.column_id for col in outputs if remaining[col.column_id] == 0]
    while frontier:
        levels.append(frontier)
        unlocked: list[str] = []
        for col_id in frontier:
            for child in dependents.get(col_id, ()):
                remaining[child] -= 1
                if remaining[child] == 0:
                    unlocked.append(child)
        frontier = unlocked

    # NOTE(review): a column whose count never reaches zero (self-reference
    # or dependency cycle) is not placed in any level -- confirm such
    # configurations are rejected before this point.
    return levels
245+
@classmethod
def get_max_concurrent_columns(cls, columns: Sequence[ColumnMetadata]) -> int:
    """Return the size of the largest dependency level of ``columns``.

    Args:
        columns: All columns of the table, in table order.

    Returns:
        The largest number of column IDs in any level returned by
        ``build_dependency_levels``, or ``1`` when there are no levels.
    """
    level_sizes = [len(level) for level in cls.build_dependency_levels(columns)]
    return max(level_sizes, default=1)
252+
147253
148254class MultiRowGenExecutor (_Executor ):
149255 def __init__ (
@@ -155,8 +261,48 @@ def __init__(
155261 project : ProjectRead ,
156262 body : MultiRowAddRequest | MultiRowRegenRequest ,
157263 ) -> None :
158- _kwargs = dict (request = request , table = table , organization = organization , project = project )
159- super ().__init__ (body = body , ** _kwargs )
264+ concurrent = body .concurrent
265+ multi_turn = (
266+ sum (getattr (col .gen_config , "multi_turn" , False ) for col in table .column_metadata ) > 0
267+ )
268+ max_concurrent_cols = self .get_max_concurrent_columns (table .column_metadata )
269+ col_batch_size , row_batch_size = determine_concurrent_batches (
270+ columns = table .column_metadata ,
271+ body = body ,
272+ concurrent = concurrent ,
273+ multi_turn = multi_turn ,
274+ cell_limit = ENV_CONFIG .concurrent_cell_batch_size ,
275+ max_concurrent_cols = max_concurrent_cols ,
276+ )
277+
278+ _context = dict (
279+ request = request ,
280+ table = table ,
281+ organization = organization ,
282+ project = project ,
283+ )
284+ super ().__init__ (
285+ body = body ,
286+ col_batch_size = col_batch_size ,
287+ row_batch_size = row_batch_size ,
288+ ** _context ,
289+ )
290+ self .log (
291+ (
292+ "Concurrency plan determined: "
293+ f"columns={ col_batch_size } , rows={ row_batch_size } , multi_turn={ multi_turn } , concurrent={ concurrent } "
294+ ),
295+ level = "DEBUG" ,
296+ columns = col_batch_size ,
297+ rows = row_batch_size ,
298+ multi_turn = multi_turn ,
299+ concurrent = concurrent ,
300+ )
301+
302+ # Store pre-computed sizes for child executors
303+ self ._col_batch_size = col_batch_size
304+ self ._row_batch_size = row_batch_size
305+
160306 # Executors
161307 if isinstance (body , MultiRowAddRequest ):
162308 self ._is_regen = False
@@ -168,7 +314,9 @@ def __init__(
168314 stream = body .stream ,
169315 concurrent = body .concurrent ,
170316 ),
171- ** _kwargs ,
317+ col_batch_size = self ._col_batch_size ,
318+ row_batch_size = self ._row_batch_size ,
319+ ** _context ,
172320 )
173321 for row_data in body .data
174322 ]
@@ -184,7 +332,9 @@ def __init__(
184332 stream = body .stream ,
185333 concurrent = body .concurrent ,
186334 ),
187- ** _kwargs ,
335+ col_batch_size = self ._col_batch_size ,
336+ row_batch_size = self ._row_batch_size ,
337+ ** _context ,
188338 )
189339 for row_id in body .row_ids
190340 ]
@@ -303,10 +453,19 @@ def __init__(
303453 organization : OrganizationRead ,
304454 project : ProjectRead ,
305455 body : RowAdd | RowRegen ,
456+ col_batch_size : int ,
457+ row_batch_size : int ,
306458 ) -> None :
307459 super ().__init__ (
308- request = request , table = table , organization = organization , project = project , body = body
460+ request = request ,
461+ table = table ,
462+ organization = organization ,
463+ project = project ,
464+ body = body ,
465+ col_batch_size = col_batch_size ,
466+ row_batch_size = row_batch_size ,
309467 )
468+
310469 # Engines
311470 self .lm = LMEngine (organization = organization , project = project , request = request )
312471 # Tasks
@@ -1032,23 +1191,6 @@ async def _load_files(self, message: ChatThreadEntry) -> ChatThreadEntry | ChatE
10321191 # logger.warning(f"{message=}")
10331192 return message
10341193
1035- def _extract_upstream_columns (self , prompt : str ) -> list [str ]:
1036- col_ids = re .findall (GEN_CONFIG_VAR_PATTERN , prompt )
1037- # return the content inside ${...}
1038- return col_ids
1039-
1040- def _extract_all_upstream_columns (self , output_column_name : str ) -> list [str ]:
1041- cols = self .table .column_metadata
1042- try :
1043- idx = next (i for i , c in enumerate (cols ) if c .column_id == output_column_name )
1044- except StopIteration :
1045- return []
1046- return [
1047- c .column_id
1048- for c in cols [:idx ]
1049- if not (c .is_info_column or c .is_state_column or c .is_vector_column )
1050- ]
1051-
10521194 def _check_upstream_error (self , upstream_cols : list [str ]) -> None :
10531195 if not isinstance (upstream_cols , list ):
10541196 raise TypeError (f"`upstream_cols` must be a list, got: { type (upstream_cols )} " )
0 commit comments