Skip to content

Commit 46ec9f8

Browse files
committed
[owl] Dynamically determined row/col concurrency (#804)
* Rely solely on the configurable cell-based limit when determining generative table concurrency. * Determine row/col concurrency dynamically before generation. * Remove the dedicated row/column batch settings and update the executor and planner tests accordingly.
1 parent 8575711 commit 46ec9f8

File tree

5 files changed

+631
-31
lines changed

5 files changed

+631
-31
lines changed

.env.example

Lines changed: 1 addition & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -48,8 +48,7 @@ JAMAI_API_BASE=http://localhost:6969/api
4848
# Configuration
4949
OWL_PORT=6969
5050
OWL_WORKERS=3
51-
OWL_CONCURRENT_ROWS_BATCH_SIZE=5
52-
OWL_CONCURRENT_COLS_BATCH_SIZE=5
51+
OWL_CONCURRENT_CELL_BATCH_SIZE=15
5352
OWL_MAX_WRITE_BATCH_SIZE=1000
5453
PB_MAX_CLIENT_CONN=500
5554
PB_MAX_CLIENT_CONN=80

services/api/src/owl/configs/oss.py

Lines changed: 1 addition & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -55,8 +55,7 @@ class EnvConfig(BaseSettings):
5555
# Starling database configs
5656
flush_clickhouse_buffer_sec: int = 60
5757
# Generative Table configs
58-
concurrent_rows_batch_size: int = 3
59-
concurrent_cols_batch_size: int = 5
58+
concurrent_cell_batch_size: int = 15
6059
max_write_batch_size: int = 100
6160
max_file_cache_size: int = 20
6261
# PDF Loader configs

services/api/src/owl/db/gen_executor.py

Lines changed: 169 additions & 27 deletions
Original file line number · Diff line number · Diff line change
@@ -1,9 +1,10 @@
11
import base64
22
import re
33
from asyncio import Queue, TaskGroup
4+
from collections import defaultdict, deque
45
from os.path import basename, splitext
56
from time import perf_counter, time
6-
from typing import Any, AsyncGenerator, Literal
7+
from typing import Any, AsyncGenerator, Literal, Sequence
78

89
import numpy as np
910
from async_lru import alru_cache
@@ -13,7 +14,11 @@
1314
from pydantic import BaseModel
1415

1516
from owl.configs import ENV_CONFIG
16-
from owl.db.gen_table import GenerativeTableCore, KnowledgeTable
17+
from owl.db.gen_table import (
18+
ColumnMetadata,
19+
GenerativeTableCore,
20+
KnowledgeTable,
21+
)
1722
from owl.docparse import GeneralDocLoader
1823
from owl.types import (
1924
AUDIO_FILE_EXTENSIONS,
@@ -56,6 +61,7 @@
5661
from owl.utils import mask_string, uuid7_draft2_str
5762
from owl.utils.billing import BillingManager
5863
from owl.utils.code import code_executor
64+
from owl.utils.concurrency import determine_concurrent_batches
5965
from owl.utils.exceptions import (
6066
BadInputError,
6167
JamaiException,
@@ -98,6 +104,8 @@ def __init__(
98104
organization: OrganizationRead,
99105
project: ProjectRead,
100106
body: MultiRowAddRequest | MultiRowRegenRequest | RowAdd | RowRegen,
107+
col_batch_size: int,
108+
row_batch_size: int,
101109
) -> None:
102110
self.request = request
103111
self._request_id: str = request.state.id
@@ -110,12 +118,12 @@ def __init__(
110118
raise ValueError(f"{body.table_id=} but {table.table_id=}")
111119
self.body = body
112120
self._stream = self.body.stream
113-
# Determine batch sizes
121+
114122
self._multi_turn = (
115123
sum(getattr(col.gen_config, "multi_turn", False) for col in table.column_metadata) > 0
116124
)
117-
self._col_batch_size = ENV_CONFIG.concurrent_cols_batch_size if body.concurrent else 1
118-
self._row_batch_size = 1 if self._multi_turn else ENV_CONFIG.concurrent_rows_batch_size
125+
self._col_batch_size = col_batch_size
126+
self._row_batch_size = row_batch_size
119127

120128
@classmethod
121129
def _log(cls, msg: str, level: str = "INFO", request_id: str = "", **kwargs):
@@ -144,6 +152,104 @@ def _log_item(x: Any) -> str:
144152
else:
145153
return f"type={type(x)}"
146154

155+
@staticmethod
156+
def _parse_prompt_dependencies(prompt: str | None) -> list[str]:
157+
if not prompt:
158+
return []
159+
return re.findall(GEN_CONFIG_VAR_PATTERN, prompt)
160+
161+
def _extract_upstream_columns(self, prompt: str | None) -> list[str]:
162+
return self._parse_prompt_dependencies(prompt)
163+
164+
def _extract_all_upstream_columns(self, output_column_name: str) -> list[str]:
165+
return self._extract_all_upstream_columns_from(
166+
self.table.column_metadata, output_column_name
167+
)
168+
169+
@staticmethod
170+
def _extract_all_upstream_columns_from(
171+
columns: Sequence[ColumnMetadata], output_column_name: str
172+
) -> list[str]:
173+
try:
174+
idx = next(i for i, c in enumerate(columns) if c.column_id == output_column_name)
175+
except StopIteration:
176+
return []
177+
return [
178+
c.column_id
179+
for c in columns[:idx]
180+
if not (c.is_info_column or c.is_state_column or c.is_vector_column)
181+
]
182+
183+
@classmethod
184+
def _collect_column_dependencies(
185+
cls,
186+
column: ColumnMetadata,
187+
*,
188+
columns: Sequence[ColumnMetadata],
189+
output_column_ids: set[str],
190+
) -> list[str]:
191+
gen_config = column.gen_config
192+
if gen_config is None:
193+
return []
194+
195+
dependencies: list[str]
196+
if isinstance(gen_config, PythonGenConfig):
197+
dependencies = cls._extract_all_upstream_columns_from(columns, column.column_id)
198+
elif isinstance(gen_config, (CodeGenConfig, EmbedGenConfig)):
199+
dependencies = [gen_config.source_column]
200+
elif isinstance(gen_config, LLMGenConfig):
201+
dependencies = cls._parse_prompt_dependencies(gen_config.prompt)
202+
else:
203+
dependencies = []
204+
205+
return [dep for dep in dependencies if dep in output_column_ids]
206+
207+
@classmethod
208+
def build_dependency_levels(cls, columns: Sequence[ColumnMetadata]) -> list[list[str]]:
209+
output_columns = [col for col in columns if col.is_output_column]
210+
if not output_columns:
211+
return []
212+
213+
adjacency: dict[str, list[str]] = defaultdict(list)
214+
in_degree: dict[str, int] = defaultdict(int)
215+
output_column_ids = {col.column_id for col in output_columns}
216+
217+
for column in output_columns:
218+
in_degree[column.column_id] = 0
219+
220+
for column in output_columns:
221+
dependencies = cls._collect_column_dependencies(
222+
column,
223+
columns=columns,
224+
output_column_ids=output_column_ids,
225+
)
226+
for dep in dependencies:
227+
adjacency[dep].append(column.column_id)
228+
in_degree[column.column_id] += 1
229+
230+
queue = deque([col.column_id for col in output_columns if in_degree[col.column_id] == 0])
231+
levels: list[list[str]] = []
232+
233+
while queue:
234+
current_level = list(queue)
235+
levels.append(current_level)
236+
queue = deque()
237+
238+
for col_id in current_level:
239+
for dependent in adjacency[col_id]:
240+
in_degree[dependent] -= 1
241+
if in_degree[dependent] == 0:
242+
queue.append(dependent)
243+
244+
return levels
245+
246+
@classmethod
247+
def get_max_concurrent_columns(cls, columns: Sequence[ColumnMetadata]) -> int:
248+
dependency_levels = cls.build_dependency_levels(columns)
249+
if not dependency_levels:
250+
return 1
251+
return max(len(level) for level in dependency_levels)
252+
147253

148254
class MultiRowGenExecutor(_Executor):
149255
def __init__(
@@ -155,8 +261,48 @@ def __init__(
155261
project: ProjectRead,
156262
body: MultiRowAddRequest | MultiRowRegenRequest,
157263
) -> None:
158-
_kwargs = dict(request=request, table=table, organization=organization, project=project)
159-
super().__init__(body=body, **_kwargs)
264+
concurrent = body.concurrent
265+
multi_turn = (
266+
sum(getattr(col.gen_config, "multi_turn", False) for col in table.column_metadata) > 0
267+
)
268+
max_concurrent_cols = self.get_max_concurrent_columns(table.column_metadata)
269+
col_batch_size, row_batch_size = determine_concurrent_batches(
270+
columns=table.column_metadata,
271+
body=body,
272+
concurrent=concurrent,
273+
multi_turn=multi_turn,
274+
cell_limit=ENV_CONFIG.concurrent_cell_batch_size,
275+
max_concurrent_cols=max_concurrent_cols,
276+
)
277+
278+
_context = dict(
279+
request=request,
280+
table=table,
281+
organization=organization,
282+
project=project,
283+
)
284+
super().__init__(
285+
body=body,
286+
col_batch_size=col_batch_size,
287+
row_batch_size=row_batch_size,
288+
**_context,
289+
)
290+
self.log(
291+
(
292+
"Concurrency plan determined: "
293+
f"columns={col_batch_size}, rows={row_batch_size}, multi_turn={multi_turn}, concurrent={concurrent}"
294+
),
295+
level="DEBUG",
296+
columns=col_batch_size,
297+
rows=row_batch_size,
298+
multi_turn=multi_turn,
299+
concurrent=concurrent,
300+
)
301+
302+
# Store pre-computed sizes for child executors
303+
self._col_batch_size = col_batch_size
304+
self._row_batch_size = row_batch_size
305+
160306
# Executors
161307
if isinstance(body, MultiRowAddRequest):
162308
self._is_regen = False
@@ -168,7 +314,9 @@ def __init__(
168314
stream=body.stream,
169315
concurrent=body.concurrent,
170316
),
171-
**_kwargs,
317+
col_batch_size=self._col_batch_size,
318+
row_batch_size=self._row_batch_size,
319+
**_context,
172320
)
173321
for row_data in body.data
174322
]
@@ -184,7 +332,9 @@ def __init__(
184332
stream=body.stream,
185333
concurrent=body.concurrent,
186334
),
187-
**_kwargs,
335+
col_batch_size=self._col_batch_size,
336+
row_batch_size=self._row_batch_size,
337+
**_context,
188338
)
189339
for row_id in body.row_ids
190340
]
@@ -303,10 +453,19 @@ def __init__(
303453
organization: OrganizationRead,
304454
project: ProjectRead,
305455
body: RowAdd | RowRegen,
456+
col_batch_size: int,
457+
row_batch_size: int,
306458
) -> None:
307459
super().__init__(
308-
request=request, table=table, organization=organization, project=project, body=body
460+
request=request,
461+
table=table,
462+
organization=organization,
463+
project=project,
464+
body=body,
465+
col_batch_size=col_batch_size,
466+
row_batch_size=row_batch_size,
309467
)
468+
310469
# Engines
311470
self.lm = LMEngine(organization=organization, project=project, request=request)
312471
# Tasks
@@ -1032,23 +1191,6 @@ async def _load_files(self, message: ChatThreadEntry) -> ChatThreadEntry | ChatE
10321191
# logger.warning(f"{message=}")
10331192
return message
10341193

1035-
def _extract_upstream_columns(self, prompt: str) -> list[str]:
1036-
col_ids = re.findall(GEN_CONFIG_VAR_PATTERN, prompt)
1037-
# return the content inside ${...}
1038-
return col_ids
1039-
1040-
def _extract_all_upstream_columns(self, output_column_name: str) -> list[str]:
1041-
cols = self.table.column_metadata
1042-
try:
1043-
idx = next(i for i, c in enumerate(cols) if c.column_id == output_column_name)
1044-
except StopIteration:
1045-
return []
1046-
return [
1047-
c.column_id
1048-
for c in cols[:idx]
1049-
if not (c.is_info_column or c.is_state_column or c.is_vector_column)
1050-
]
1051-
10521194
def _check_upstream_error(self, upstream_cols: list[str]) -> None:
10531195
if not isinstance(upstream_cols, list):
10541196
raise TypeError(f"`upstream_cols` must be a list, got: {type(upstream_cols)}")

0 commit comments

Comments (0)