Skip to content

Commit 4b3b9a2

Browse files
authored
feat(flow-control): add bytes-based flow control for source (#709)
* feat(flow-control): add bytes-based flow control for source * docs: fix broken links
1 parent 8eb10c9 commit 4b3b9a2

File tree

14 files changed

+217
-80
lines changed

14 files changed

+217
-80
lines changed

docs/docs/core/flow_def.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ If nothing changed during the last refresh cycle, only list operations will be p
156156
You can pass the following arguments to `add_source()` to control the concurrency of the source operation:
157157

158158
* `max_inflight_rows`: the maximum number of concurrent inflight requests for the source operation.
159+
* `max_inflight_bytes`: the maximum number of concurrent inflight bytes for the source operation.
159160

160161
The default value can be specified by [`DefaultExecutionOptions`](/docs/core/settings#defaultexecutionoptions) or corresponding [environment variable](/docs/core/settings#list-of-environment-variables).
161162

docs/docs/core/settings.mdx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,10 @@ If you use the Postgres database hosted by [Supabase](https://supabase.com/), pl
109109

110110
`DefaultExecutionOptions` is used to configure the default execution options for the flow. It has the following fields:
111111

112-
* `source_max_inflight_rows` (type: `int`, optional): The maximum number of concurrent inflight requests for source operations. This only provides a default value, and can be overridden by the `max_inflight_rows` argument passed to `FlowBuilder.add_source()` on a per-source basis.
112+
* `source_max_inflight_rows` (type: `int`, optional): The maximum number of concurrent inflight requests for source operations.
113+
* `source_max_inflight_bytes` (type: `int`, optional): The maximum number of concurrent inflight bytes for source operations.
114+
115+
The options provide default values, and can be overridden by arguments passed to `FlowBuilder.add_source()` on a per-source basis ([details](/docs/core/flow_def#concurrency-control)).
113116

114117
## List of Environment Variables
115118

@@ -122,3 +125,4 @@ This is the list of environment variables, each of which has a corresponding fie
122125
| `COCOINDEX_DATABASE_USER` | `database.user` | No |
123126
| `COCOINDEX_DATABASE_PASSWORD` | `database.password` | No |
124127
| `COCOINDEX_SOURCE_MAX_INFLIGHT_ROWS` | `default_execution_options.source_max_inflight_rows` | No |
128+
| `COCOINDEX_SOURCE_MAX_INFLIGHT_BYTES` | `default_execution_options.source_max_inflight_bytes` | No |

python/cocoindex/flow.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,7 @@ class _SourceRefreshOptions:
419419
@dataclass
420420
class _ExecutionOptions:
421421
max_inflight_rows: int | None = None
422+
max_inflight_bytes: int | None = None
422423

423424

424425
class FlowBuilder:
@@ -445,6 +446,7 @@ def add_source(
445446
name: str | None = None,
446447
refresh_interval: datetime.timedelta | None = None,
447448
max_inflight_rows: int | None = None,
449+
max_inflight_bytes: int | None = None,
448450
) -> DataSlice[T]:
449451
"""
450452
Import a source to the flow.
@@ -464,7 +466,10 @@ def add_source(
464466
_SourceRefreshOptions(refresh_interval=refresh_interval)
465467
),
466468
execution_options=dump_engine_object(
467-
_ExecutionOptions(max_inflight_rows=max_inflight_rows)
469+
_ExecutionOptions(
470+
max_inflight_rows=max_inflight_rows,
471+
max_inflight_bytes=max_inflight_bytes,
472+
)
468473
),
469474
),
470475
name,

python/cocoindex/setting.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class DefaultExecutionOptions:
4949

5050
# The maximum number of concurrent inflight requests.
5151
source_max_inflight_rows: int | None = 256
52+
source_max_inflight_bytes: int | None = 1024 * 1024 * 1024
5253

5354

5455
def _load_field(
@@ -103,6 +104,12 @@ def from_env(cls) -> Self:
103104
"COCOINDEX_SOURCE_MAX_INFLIGHT_ROWS",
104105
parse=int,
105106
)
107+
_load_field(
108+
exec_kwargs,
109+
"source_max_inflight_bytes",
110+
"COCOINDEX_SOURCE_MAX_INFLIGHT_BYTES",
111+
parse=int,
112+
)
106113
default_execution_options = DefaultExecutionOptions(**exec_kwargs)
107114

108115
app_namespace = os.getenv("COCOINDEX_APP_NAMESPACE", "")

src/base/spec.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ impl SpecFormatter for OpSpec {
255255

256256
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
257257
pub struct ExecutionOptions {
258-
pub max_inflight_rows: Option<u32>,
258+
pub max_inflight_rows: Option<usize>,
259+
pub max_inflight_bytes: Option<usize>,
259260
}
260261

261262
#[derive(Debug, Clone, Serialize, Deserialize, Default)]

src/base/value.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ impl KeyValue {
361361
}
362362
}
363363

364-
pub fn estimated_detached_byte_size(&self) -> usize {
364+
fn estimated_detached_byte_size(&self) -> usize {
365365
match self {
366366
KeyValue::Bytes(v) => v.len(),
367367
KeyValue::Str(v) => v.len(),
@@ -568,7 +568,7 @@ impl BasicValue {
568568
}
569569

570570
/// Returns the estimated byte size of the value, for detached data (i.e. allocated on heap).
571-
pub fn estimated_detached_byte_size(&self) -> usize {
571+
fn estimated_detached_byte_size(&self) -> usize {
572572
fn json_estimated_detached_byte_size(val: &serde_json::Value) -> usize {
573573
match val {
574574
serde_json::Value::String(s) => s.len(),
@@ -862,7 +862,7 @@ impl Value<ScopeValue> {
862862
Value::Null => 0,
863863
Value::Basic(v) => v.estimated_detached_byte_size(),
864864
Value::Struct(v) => v.estimated_detached_byte_size(),
865-
(Value::UTable(v) | Value::LTable(v)) => {
865+
Value::UTable(v) | Value::LTable(v) => {
866866
v.iter()
867867
.map(|v| v.estimated_detached_byte_size())
868868
.sum::<usize>()
@@ -955,13 +955,17 @@ where
955955
}
956956

957957
impl FieldValues<ScopeValue> {
958-
pub fn estimated_detached_byte_size(&self) -> usize {
958+
fn estimated_detached_byte_size(&self) -> usize {
959959
self.fields
960960
.iter()
961961
.map(Value::estimated_byte_size)
962962
.sum::<usize>()
963963
+ self.fields.len() * std::mem::size_of::<Value<ScopeValue>>()
964964
}
965+
966+
pub fn estimated_byte_size(&self) -> usize {
967+
self.estimated_detached_byte_size() + std::mem::size_of::<Self>()
968+
}
965969
}
966970

967971
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]

src/builder/analyzer.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,12 @@ impl AnalyzerContext {
695695
.default_execution_options
696696
.source_max_inflight_rows
697697
});
698+
let max_inflight_bytes =
699+
(import_op.spec.execution_options.max_inflight_bytes).or_else(|| {
700+
self.lib_ctx
701+
.default_execution_options
702+
.source_max_inflight_bytes
703+
});
698704
let result_fut = async move {
699705
trace!("Start building executor for source op `{}`", op_name);
700706
let executor = executor.await?;
@@ -705,7 +711,10 @@ impl AnalyzerContext {
705711
primary_key_type,
706712
name: op_name,
707713
refresh_options: import_op.spec.refresh_options,
708-
concurrency_controller: utils::ConcurrencyController::new(max_inflight_rows),
714+
concurrency_controller: concur_control::ConcurrencyController::new(
715+
max_inflight_rows,
716+
max_inflight_bytes,
717+
),
709718
})
710719
};
711720
Ok(result_fut)

src/builder/plan.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ pub struct AnalyzedImportOp {
5656
pub output: AnalyzedOpOutput,
5757
pub primary_key_type: schema::ValueType,
5858
pub refresh_options: spec::SourceRefreshOptions,
59-
pub concurrency_controller: utils::ConcurrencyController,
59+
60+
pub concurrency_controller: concur_control::ConcurrencyController,
6061
}
6162

6263
pub struct AnalyzedFunctionExecInfo {

src/execution/live_updater.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,15 @@ async fn update_source(
128128
});
129129
for change in change_msg.changes {
130130
let ack_fn = ack_fn.clone();
131+
let concur_permit = import_op
132+
.concurrency_controller
133+
.acquire(concur_control::BYTES_UNKNOWN_YET)
134+
.await?;
131135
tokio::spawn(source_context.clone().process_source_key(
132136
change.key,
133137
change.data,
134138
change_stream_stats.clone(),
139+
concur_permit,
135140
ack_fn.map(|ack_fn| {
136141
move || async move { SharedAckFn::ack(&ack_fn).await }
137142
}),

src/execution/source_indexer.rs

Lines changed: 79 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ impl SourceIndexingContext {
100100
key: value::KeyValue,
101101
source_data: Option<interface::SourceData>,
102102
update_stats: Arc<stats::UpdateStats>,
103+
_concur_permit: concur_control::ConcurrencyControllerPermit,
103104
ack_fn: Option<AckFn>,
104105
pool: PgPool,
105106
) {
@@ -160,66 +161,76 @@ impl SourceIndexingContext {
160161
}
161162
};
162163

163-
let permit = processing_sem.acquire().await?;
164-
let result = row_indexer::update_source_row(
165-
&SourceRowEvaluationContext {
166-
plan: &plan,
167-
import_op,
168-
schema,
169-
key: &key,
170-
import_op_idx: self.source_idx,
171-
},
172-
&self.setup_execution_ctx,
173-
source_data.value,
174-
&source_version,
175-
&pool,
176-
&update_stats,
177-
)
178-
.await?;
179-
let target_source_version = if let SkippedOr::Skipped(existing_source_version) = result
180164
{
181-
Some(existing_source_version)
182-
} else if source_version.kind == row_indexer::SourceVersionKind::NonExistence {
183-
Some(source_version)
184-
} else {
185-
None
186-
};
187-
if let Some(target_source_version) = target_source_version {
188-
let mut state = self.state.lock().unwrap();
189-
let scan_generation = state.scan_generation;
190-
let entry = state.rows.entry(key.clone());
191-
match entry {
192-
hash_map::Entry::Occupied(mut entry) => {
193-
if !entry
194-
.get()
195-
.source_version
196-
.should_skip(&target_source_version, None)
197-
{
198-
if target_source_version.kind
199-
== row_indexer::SourceVersionKind::NonExistence
165+
let _processing_permit = processing_sem.acquire().await?;
166+
let _concur_permit = match &source_data.value {
167+
interface::SourceValue::Existence(value) => {
168+
import_op
169+
.concurrency_controller
170+
.acquire_bytes_with_reservation(|| value.estimated_byte_size())
171+
.await?
172+
}
173+
interface::SourceValue::NonExistence => None,
174+
};
175+
let result = row_indexer::update_source_row(
176+
&SourceRowEvaluationContext {
177+
plan: &plan,
178+
import_op,
179+
schema,
180+
key: &key,
181+
import_op_idx: self.source_idx,
182+
},
183+
&self.setup_execution_ctx,
184+
source_data.value,
185+
&source_version,
186+
&pool,
187+
&update_stats,
188+
)
189+
.await?;
190+
let target_source_version =
191+
if let SkippedOr::Skipped(existing_source_version) = result {
192+
Some(existing_source_version)
193+
} else if source_version.kind == row_indexer::SourceVersionKind::NonExistence {
194+
Some(source_version)
195+
} else {
196+
None
197+
};
198+
if let Some(target_source_version) = target_source_version {
199+
let mut state = self.state.lock().unwrap();
200+
let scan_generation = state.scan_generation;
201+
let entry = state.rows.entry(key.clone());
202+
match entry {
203+
hash_map::Entry::Occupied(mut entry) => {
204+
if !entry
205+
.get()
206+
.source_version
207+
.should_skip(&target_source_version, None)
200208
{
201-
entry.remove();
202-
} else {
203-
let mut_entry = entry.get_mut();
204-
mut_entry.source_version = target_source_version;
205-
mut_entry.touched_generation = scan_generation;
209+
if target_source_version.kind
210+
== row_indexer::SourceVersionKind::NonExistence
211+
{
212+
entry.remove();
213+
} else {
214+
let mut_entry = entry.get_mut();
215+
mut_entry.source_version = target_source_version;
216+
mut_entry.touched_generation = scan_generation;
217+
}
206218
}
207219
}
208-
}
209-
hash_map::Entry::Vacant(entry) => {
210-
if target_source_version.kind
211-
!= row_indexer::SourceVersionKind::NonExistence
212-
{
213-
entry.insert(SourceRowIndexingState {
214-
source_version: target_source_version,
215-
touched_generation: scan_generation,
216-
..Default::default()
217-
});
220+
hash_map::Entry::Vacant(entry) => {
221+
if target_source_version.kind
222+
!= row_indexer::SourceVersionKind::NonExistence
223+
{
224+
entry.insert(SourceRowIndexingState {
225+
source_version: target_source_version,
226+
touched_generation: scan_generation,
227+
..Default::default()
228+
});
229+
}
218230
}
219231
}
220232
}
221233
}
222-
drop(permit);
223234
if let Some(ack_fn) = ack_fn {
224235
ack_fn().await?;
225236
}
@@ -243,6 +254,7 @@ impl SourceIndexingContext {
243254
key: value::KeyValue,
244255
source_version: SourceVersion,
245256
update_stats: &Arc<stats::UpdateStats>,
257+
concur_permit: concur_control::ConcurrencyControllerPermit,
246258
pool: &PgPool,
247259
) -> Option<impl Future<Output = ()> + Send + 'static> {
248260
{
@@ -257,10 +269,14 @@ impl SourceIndexingContext {
257269
return None;
258270
}
259271
}
260-
Some(
261-
self.clone()
262-
.process_source_key(key, None, update_stats.clone(), NO_ACK, pool.clone()),
263-
)
272+
Some(self.clone().process_source_key(
273+
key,
274+
None,
275+
update_stats.clone(),
276+
concur_permit,
277+
NO_ACK,
278+
pool.clone(),
279+
))
264280
}
265281

266282
pub async fn update(
@@ -282,15 +298,19 @@ impl SourceIndexingContext {
282298
state.scan_generation
283299
};
284300
while let Some(row) = rows_stream.next().await {
285-
let _ = import_op.concurrency_controller.acquire().await?;
286301
for row in row? {
302+
let concur_permit = import_op
303+
.concurrency_controller
304+
.acquire(concur_control::BYTES_UNKNOWN_YET)
305+
.await?;
287306
self.process_source_key_if_newer(
288307
row.key,
289308
SourceVersion::from_current_with_ordinal(
290309
row.ordinal
291310
.ok_or_else(|| anyhow::anyhow!("ordinal is not available"))?,
292311
),
293312
update_stats,
313+
concur_permit,
294314
pool,
295315
)
296316
.map(|fut| join_set.spawn(fut));
@@ -322,10 +342,12 @@ impl SourceIndexingContext {
322342
value: interface::SourceValue::NonExistence,
323343
ordinal: source_ordinal,
324344
});
345+
let concur_permit = import_op.concurrency_controller.acquire(Some(|| 0)).await?;
325346
join_set.spawn(self.clone().process_source_key(
326347
key,
327348
source_data,
328349
update_stats.clone(),
350+
concur_permit,
329351
NO_ACK,
330352
pool.clone(),
331353
));

0 commit comments

Comments
 (0)