From cb939703e8032067bc0d5deaa85a7e8a3fe2bc2f Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Mon, 7 Jul 2025 12:07:58 -0700 Subject: [PATCH 1/2] feat(flow-control): add bytes-based flow control for source --- docs/docs/core/flow_def.mdx | 1 + docs/docs/core/settings.mdx | 6 +- python/cocoindex/flow.py | 7 +- python/cocoindex/setting.py | 7 ++ src/base/spec.rs | 3 +- src/base/value.rs | 12 ++- src/builder/analyzer.rs | 11 ++- src/builder/plan.rs | 3 +- src/execution/live_updater.rs | 5 ++ src/execution/source_indexer.rs | 136 +++++++++++++++++++------------- src/prelude.rs | 2 +- src/settings.rs | 3 +- src/utils/concur_control.rs | 97 ++++++++++++++++++++--- src/utils/mod.rs | 4 +- 14 files changed, 217 insertions(+), 80 deletions(-) diff --git a/docs/docs/core/flow_def.mdx b/docs/docs/core/flow_def.mdx index 9b2d37074..41fd27f94 100644 --- a/docs/docs/core/flow_def.mdx +++ b/docs/docs/core/flow_def.mdx @@ -156,6 +156,7 @@ If nothing changed during the last refresh cycle, only list operations will be p You can pass the following arguments to `add_source()` to control the concurrency of the source operation: * `max_inflight_rows`: the maximum number of concurrent inflight requests for the source operation. +* `max_inflight_bytes`: the maximum number of concurrent inflight bytes for the source operation. The default value can be specified by [`DefaultExecutionOptions`](/docs/core/settings#defaultexecutionoptions) or corresponding [environment variable](/docs/core/settings#list-of-environment-variables). diff --git a/docs/docs/core/settings.mdx b/docs/docs/core/settings.mdx index 40f774888..1218211f3 100644 --- a/docs/docs/core/settings.mdx +++ b/docs/docs/core/settings.mdx @@ -109,7 +109,10 @@ If you use the Postgres database hosted by [Supabase](https://supabase.com/), pl `DefaultExecutionOptions` is used to configure the default execution options for the flow. 
It has the following fields: -* `source_max_inflight_rows` (type: `int`, optional): The maximum number of concurrent inflight requests for source operations. This only provides an default value, and can be overridden by the `max_inflight_rows` argument passed to `FlowBuilder.add_source()` on per-source basis. +* `source_max_inflight_rows` (type: `int`, optional): The maximum number of concurrent inflight requests for source operations. +* `source_max_inflight_bytes` (type: `int`, optional): The maximum number of concurrent inflight bytes for source operations. + +The options only provide a default value, and can be overridden by arguments passed to `FlowBuilder.add_source()` on per-source basis ([details](/docs/build/flow_def#concurrency-control)). ## List of Environment Variables @@ -122,3 +125,4 @@ This is the list of environment variables, each of which has a corresponding fie | `COCOINDEX_DATABASE_USER` | `database.user` | No | | `COCOINDEX_DATABASE_PASSWORD` | `database.password` | No | | `COCOINDEX_SOURCE_MAX_INFLIGHT_ROWS` | `default_execution_options.source_max_inflight_rows` | No | +| `COCOINDEX_SOURCE_MAX_INFLIGHT_BYTES` | `default_execution_options.source_max_inflight_bytes` | No | diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index e05e7478b..498e068cb 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -419,6 +419,7 @@ class _SourceRefreshOptions: @dataclass class _ExecutionOptions: max_inflight_rows: int | None = None + max_inflight_bytes: int | None = None class FlowBuilder: @@ -445,6 +446,7 @@ def add_source( name: str | None = None, refresh_interval: datetime.timedelta | None = None, max_inflight_rows: int | None = None, + max_inflight_bytes: int | None = None, ) -> DataSlice[T]: """ Import a source to the flow. 
@@ -464,7 +466,10 @@ def add_source( _SourceRefreshOptions(refresh_interval=refresh_interval) ), execution_options=dump_engine_object( - _ExecutionOptions(max_inflight_rows=max_inflight_rows) + _ExecutionOptions( + max_inflight_rows=max_inflight_rows, + max_inflight_bytes=max_inflight_bytes, + ) ), ), name, diff --git a/python/cocoindex/setting.py b/python/cocoindex/setting.py index cab71ad3b..aca144a57 100644 --- a/python/cocoindex/setting.py +++ b/python/cocoindex/setting.py @@ -49,6 +49,7 @@ class DefaultExecutionOptions: # The maximum number of concurrent inflight requests. source_max_inflight_rows: int | None = 256 + source_max_inflight_bytes: int | None = 1024 * 1024 * 1024 def _load_field( @@ -103,6 +104,12 @@ def from_env(cls) -> Self: "COCOINDEX_SOURCE_MAX_INFLIGHT_ROWS", parse=int, ) + _load_field( + exec_kwargs, + "source_max_inflight_bytes", + "COCOINDEX_SOURCE_MAX_INFLIGHT_BYTES", + parse=int, + ) default_execution_options = DefaultExecutionOptions(**exec_kwargs) app_namespace = os.getenv("COCOINDEX_APP_NAMESPACE", "") diff --git a/src/base/spec.rs b/src/base/spec.rs index da3684888..8ecf3b232 100644 --- a/src/base/spec.rs +++ b/src/base/spec.rs @@ -255,7 +255,8 @@ impl SpecFormatter for OpSpec { #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct ExecutionOptions { - pub max_inflight_rows: Option, + pub max_inflight_rows: Option, + pub max_inflight_bytes: Option, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] diff --git a/src/base/value.rs b/src/base/value.rs index 29e96cea1..49c7d949c 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -361,7 +361,7 @@ impl KeyValue { } } - pub fn estimated_detached_byte_size(&self) -> usize { + fn estimated_detached_byte_size(&self) -> usize { match self { KeyValue::Bytes(v) => v.len(), KeyValue::Str(v) => v.len(), @@ -568,7 +568,7 @@ impl BasicValue { } /// Returns the estimated byte size of the value, for detached data (i.e. allocated on heap). 
- pub fn estimated_detached_byte_size(&self) -> usize { + fn estimated_detached_byte_size(&self) -> usize { fn json_estimated_detached_byte_size(val: &serde_json::Value) -> usize { match val { serde_json::Value::String(s) => s.len(), @@ -862,7 +862,7 @@ impl Value { Value::Null => 0, Value::Basic(v) => v.estimated_detached_byte_size(), Value::Struct(v) => v.estimated_detached_byte_size(), - (Value::UTable(v) | Value::LTable(v)) => { + Value::UTable(v) | Value::LTable(v) => { v.iter() .map(|v| v.estimated_detached_byte_size()) .sum::() @@ -955,13 +955,17 @@ where } impl FieldValues { - pub fn estimated_detached_byte_size(&self) -> usize { + fn estimated_detached_byte_size(&self) -> usize { self.fields .iter() .map(Value::estimated_byte_size) .sum::() + self.fields.len() * std::mem::size_of::>() } + + pub fn estimated_byte_size(&self) -> usize { + self.estimated_detached_byte_size() + std::mem::size_of::() + } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] diff --git a/src/builder/analyzer.rs b/src/builder/analyzer.rs index 6366a5482..217f13513 100644 --- a/src/builder/analyzer.rs +++ b/src/builder/analyzer.rs @@ -695,6 +695,12 @@ impl AnalyzerContext { .default_execution_options .source_max_inflight_rows }); + let max_inflight_bytes = + (import_op.spec.execution_options.max_inflight_bytes).or_else(|| { + self.lib_ctx + .default_execution_options + .source_max_inflight_bytes + }); let result_fut = async move { trace!("Start building executor for source op `{}`", op_name); let executor = executor.await?; @@ -705,7 +711,10 @@ impl AnalyzerContext { primary_key_type, name: op_name, refresh_options: import_op.spec.refresh_options, - concurrency_controller: utils::ConcurrencyController::new(max_inflight_rows), + concurrency_controller: concur_control::ConcurrencyController::new( + max_inflight_rows, + max_inflight_bytes, + ), }) }; Ok(result_fut) diff --git a/src/builder/plan.rs b/src/builder/plan.rs index 9e487f823..7bf554192 100644 --- 
a/src/builder/plan.rs +++ b/src/builder/plan.rs @@ -56,7 +56,8 @@ pub struct AnalyzedImportOp { pub output: AnalyzedOpOutput, pub primary_key_type: schema::ValueType, pub refresh_options: spec::SourceRefreshOptions, - pub concurrency_controller: utils::ConcurrencyController, + + pub concurrency_controller: concur_control::ConcurrencyController, } pub struct AnalyzedFunctionExecInfo { diff --git a/src/execution/live_updater.rs b/src/execution/live_updater.rs index a2a236921..43dc6b369 100644 --- a/src/execution/live_updater.rs +++ b/src/execution/live_updater.rs @@ -128,10 +128,15 @@ async fn update_source( }); for change in change_msg.changes { let ack_fn = ack_fn.clone(); + let concur_permit = import_op + .concurrency_controller + .acquire(concur_control::BYTES_UNKNOWN_YET) + .await?; tokio::spawn(source_context.clone().process_source_key( change.key, change.data, change_stream_stats.clone(), + concur_permit, ack_fn.map(|ack_fn| { move || async move { SharedAckFn::ack(&ack_fn).await } }), diff --git a/src/execution/source_indexer.rs b/src/execution/source_indexer.rs index c8a6f6884..11451bb6f 100644 --- a/src/execution/source_indexer.rs +++ b/src/execution/source_indexer.rs @@ -100,6 +100,7 @@ impl SourceIndexingContext { key: value::KeyValue, source_data: Option, update_stats: Arc, + _concur_permit: concur_control::ConcurrencyControllerPermit, ack_fn: Option, pool: PgPool, ) { @@ -160,66 +161,76 @@ impl SourceIndexingContext { } }; - let permit = processing_sem.acquire().await?; - let result = row_indexer::update_source_row( - &SourceRowEvaluationContext { - plan: &plan, - import_op, - schema, - key: &key, - import_op_idx: self.source_idx, - }, - &self.setup_execution_ctx, - source_data.value, - &source_version, - &pool, - &update_stats, - ) - .await?; - let target_source_version = if let SkippedOr::Skipped(existing_source_version) = result { - Some(existing_source_version) - } else if source_version.kind == row_indexer::SourceVersionKind::NonExistence { - 
Some(source_version) - } else { - None - }; - if let Some(target_source_version) = target_source_version { - let mut state = self.state.lock().unwrap(); - let scan_generation = state.scan_generation; - let entry = state.rows.entry(key.clone()); - match entry { - hash_map::Entry::Occupied(mut entry) => { - if !entry - .get() - .source_version - .should_skip(&target_source_version, None) - { - if target_source_version.kind - == row_indexer::SourceVersionKind::NonExistence + let _processing_permit = processing_sem.acquire().await?; + let _concur_permit = match &source_data.value { + interface::SourceValue::Existence(value) => { + import_op + .concurrency_controller + .acquire_bytes_with_reservation(|| value.estimated_byte_size()) + .await? + } + interface::SourceValue::NonExistence => None, + }; + let result = row_indexer::update_source_row( + &SourceRowEvaluationContext { + plan: &plan, + import_op, + schema, + key: &key, + import_op_idx: self.source_idx, + }, + &self.setup_execution_ctx, + source_data.value, + &source_version, + &pool, + &update_stats, + ) + .await?; + let target_source_version = + if let SkippedOr::Skipped(existing_source_version) = result { + Some(existing_source_version) + } else if source_version.kind == row_indexer::SourceVersionKind::NonExistence { + Some(source_version) + } else { + None + }; + if let Some(target_source_version) = target_source_version { + let mut state = self.state.lock().unwrap(); + let scan_generation = state.scan_generation; + let entry = state.rows.entry(key.clone()); + match entry { + hash_map::Entry::Occupied(mut entry) => { + if !entry + .get() + .source_version + .should_skip(&target_source_version, None) { - entry.remove(); - } else { - let mut_entry = entry.get_mut(); - mut_entry.source_version = target_source_version; - mut_entry.touched_generation = scan_generation; + if target_source_version.kind + == row_indexer::SourceVersionKind::NonExistence + { + entry.remove(); + } else { + let mut_entry = entry.get_mut(); 
+ mut_entry.source_version = target_source_version; + mut_entry.touched_generation = scan_generation; + } } } - } - hash_map::Entry::Vacant(entry) => { - if target_source_version.kind - != row_indexer::SourceVersionKind::NonExistence - { - entry.insert(SourceRowIndexingState { - source_version: target_source_version, - touched_generation: scan_generation, - ..Default::default() - }); + hash_map::Entry::Vacant(entry) => { + if target_source_version.kind + != row_indexer::SourceVersionKind::NonExistence + { + entry.insert(SourceRowIndexingState { + source_version: target_source_version, + touched_generation: scan_generation, + ..Default::default() + }); + } } } } } - drop(permit); if let Some(ack_fn) = ack_fn { ack_fn().await?; } @@ -243,6 +254,7 @@ impl SourceIndexingContext { key: value::KeyValue, source_version: SourceVersion, update_stats: &Arc, + concur_permit: concur_control::ConcurrencyControllerPermit, pool: &PgPool, ) -> Option + Send + 'static> { { @@ -257,10 +269,14 @@ impl SourceIndexingContext { return None; } } - Some( - self.clone() - .process_source_key(key, None, update_stats.clone(), NO_ACK, pool.clone()), - ) + Some(self.clone().process_source_key( + key, + None, + update_stats.clone(), + concur_permit, + NO_ACK, + pool.clone(), + )) } pub async fn update( @@ -282,8 +298,11 @@ impl SourceIndexingContext { state.scan_generation }; while let Some(row) = rows_stream.next().await { - let _ = import_op.concurrency_controller.acquire().await?; for row in row? 
{ + let concur_permit = import_op + .concurrency_controller + .acquire(concur_control::BYTES_UNKNOWN_YET) + .await?; self.process_source_key_if_newer( row.key, SourceVersion::from_current_with_ordinal( @@ -291,6 +310,7 @@ impl SourceIndexingContext { .ok_or_else(|| anyhow::anyhow!("ordinal is not available"))?, ), update_stats, + concur_permit, pool, ) .map(|fut| join_set.spawn(fut)); @@ -322,10 +342,12 @@ impl SourceIndexingContext { value: interface::SourceValue::NonExistence, ordinal: source_ordinal, }); + let concur_permit = import_op.concurrency_controller.acquire(Some(|| 0)).await?; join_set.spawn(self.clone().process_source_key( key, source_data, update_stats.clone(), + concur_permit, NO_ACK, pool.clone(), )); diff --git a/src/prelude.rs b/src/prelude.rs index e5ad78f27..5f5f0f365 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -26,7 +26,7 @@ pub(crate) use crate::ops::interface; pub(crate) use crate::service::error::{ApiError, invariance_violation}; pub(crate) use crate::setup; pub(crate) use crate::setup::AuthRegistry; -pub(crate) use crate::utils::{self, retryable}; +pub(crate) use crate::utils::{self, concur_control, retryable}; pub(crate) use crate::{api_bail, api_error}; pub(crate) use anyhow::{anyhow, bail}; diff --git a/src/settings.rs b/src/settings.rs index d3acd4b1f..e79e9af19 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -9,7 +9,8 @@ pub struct DatabaseConnectionSpec { #[derive(Deserialize, Debug, Default)] pub struct DefaultExecutionOptions { - pub source_max_inflight_rows: Option, + pub source_max_inflight_rows: Option, + pub source_max_inflight_bytes: Option, } #[derive(Deserialize, Debug, Default)] diff --git a/src/utils/concur_control.rs b/src/utils/concur_control.rs index eec06a7f0..6b533447f 100644 --- a/src/utils/concur_control.rs +++ b/src/utils/concur_control.rs @@ -1,30 +1,109 @@ use crate::prelude::*; -use tokio::sync::{Semaphore, SemaphorePermit}; +use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore}; -pub 
struct ConcurrencyController { - inflight_count_sem: Option, +struct WeightedSemaphore { + downscale_factor: u8, + downscaled_quota: u32, + sem: Arc, +} + +impl WeightedSemaphore { + pub fn new(quota: usize) -> Self { + let mut downscale_factor = 0; + let mut downscaled_quota = quota; + while downscaled_quota > u32::MAX as usize { + downscaled_quota >>= 1; + downscale_factor += 1; + } + let sem = Arc::new(Semaphore::new(downscaled_quota)); + Self { + downscaled_quota: downscaled_quota as u32, + downscale_factor, + sem, + } + } + + async fn acquire_reservation(&self) -> Result { + self.sem.clone().acquire_owned().await + } + + async fn acquire<'a>( + &'a self, + weight: usize, + reserved: bool, + ) -> Result, AcquireError> { + let downscaled_weight = (weight >> self.downscale_factor) as u32; + let capped_weight = downscaled_weight.min(self.downscaled_quota); + let reserved_weight = if reserved { 1 } else { 0 }; + if reserved_weight >= capped_weight { + return Ok(None); + } + Ok(Some( + self.sem + .clone() + .acquire_many_owned(capped_weight - reserved_weight) + .await?, + )) + } } -pub struct ConcurrencyControllerPermit<'a> { - _inflight_count_permit: Option>, +pub struct ConcurrencyControllerPermit { + _inflight_count_permit: Option, + _inflight_bytes_permit: Option, } +pub struct ConcurrencyController { + inflight_count_sem: Option>, + inflight_bytes_sem: Option, +} + +pub static BYTES_UNKNOWN_YET: Option usize> = None; + impl ConcurrencyController { - pub fn new(max_inflight_count: Option) -> Self { + pub fn new(max_inflight_count: Option, max_inflight_bytes: Option) -> Self { Self { - inflight_count_sem: max_inflight_count.map(|max| Semaphore::new(max as usize)), + inflight_count_sem: max_inflight_count.map(|max| Arc::new(Semaphore::new(max))), + inflight_bytes_sem: max_inflight_bytes.map(|max| WeightedSemaphore::new(max)), } } - pub async fn acquire<'a>(&'a self) -> Result> { + /// If `bytes_fn` is `None`, it means the number of bytes is not known yet. 
+ /// The controller will reserve a minimum number of bytes. + /// The caller should call `acquire_bytes_with_reservation` with the actual number of bytes later. + pub async fn acquire( + &self, + bytes_fn: Option usize>, + ) -> Result { let inflight_count_permit = if let Some(sem) = &self.inflight_count_sem { - Some(sem.acquire().await?) + Some(sem.clone().acquire_owned().await?) + } else { + None + }; + let inflight_bytes_permit = if let Some(sem) = &self.inflight_bytes_sem { + if let Some(bytes_fn) = bytes_fn { + let n = bytes_fn(); + sem.acquire(n, false).await? + } else { + Some(sem.acquire_reservation().await?) + } } else { None }; Ok(ConcurrencyControllerPermit { _inflight_count_permit: inflight_count_permit, + _inflight_bytes_permit: inflight_bytes_permit, }) } + + pub async fn acquire_bytes_with_reservation<'a>( + &'a self, + bytes_fn: impl FnOnce() -> usize, + ) -> Result, AcquireError> { + if let Some(sem) = &self.inflight_bytes_sem { + sem.acquire(bytes_fn(), true).await + } else { + Ok(None) + } + } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 81c5e38e8..a13e05b8c 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,8 +1,6 @@ +pub mod concur_control; pub mod db; pub mod fingerprint; pub mod immutable; pub mod retryable; pub mod yaml_ser; - -mod concur_control; -pub use concur_control::ConcurrencyController; From aa761ad3707bd7498887232a88b40f1a2d7a2636 Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Mon, 7 Jul 2025 16:52:48 -0700 Subject: [PATCH 2/2] docs: fix broken links --- docs/docs/core/settings.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/core/settings.mdx b/docs/docs/core/settings.mdx index 1218211f3..bcee1105d 100644 --- a/docs/docs/core/settings.mdx +++ b/docs/docs/core/settings.mdx @@ -112,7 +112,7 @@ If you use the Postgres database hosted by [Supabase](https://supabase.com/), pl * `source_max_inflight_rows` (type: `int`, optional): The maximum number of concurrent inflight requests for 
source operations.
 * `source_max_inflight_bytes` (type: `int`, optional): The maximum number of concurrent inflight bytes for source operations.
 
-The options only provide a default value, and can be overridden by arguments passed to `FlowBuilder.add_source()` on per-source basis ([details](/docs/build/flow_def#concurrency-control)).
+The options provide default values, and can be overridden by arguments passed to `FlowBuilder.add_source()` on a per-source basis ([details](/docs/core/flow_def#concurrency-control)).
 
 ## List of Environment Variables