Skip to content

Commit 2c6a30e

Browse files
authored
feat(flow-control): support flow control for "for each" operations (#712)
* feat(flow-control): support flow control for "for each" operations * docs: clarify
1 parent fb39083 commit 2c6a30e

File tree

8 files changed

+175
-64
lines changed

8 files changed

+175
-64
lines changed

docs/docs/core/flow_def.mdx

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ You can pass the following arguments to `add_source()` to control the concurrenc
158158
* `max_inflight_rows`: the maximum number of concurrent inflight requests for the source operation.
159159
* `max_inflight_bytes`: the maximum number of concurrent inflight bytes for the source operation.
160160

161+
For example:
162+
163+
```py
164+
@cocoindex.flow_def(name="DemoFlow")
165+
def demo_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):
166+
data_scope["documents"] = flow_builder.add_source(
167+
DemoSourceSpec(...), max_inflight_rows=10, max_inflight_bytes=100*1024*1024)
168+
...
169+
```
170+
161171
The default value can be specified by [`DefaultExecutionOptions`](/docs/core/settings#defaultexecutionoptions) or corresponding [environment variable](/docs/core/settings#list-of-environment-variables).
162172

163173
### Transform
@@ -204,6 +214,25 @@ def demo_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataSco
204214
</TabItem>
205215
</Tabs>
206216

217+
#### Concurrency control
218+
219+
You can pass the following arguments to `row()` to control the concurrency of the for-each operation:
220+
221+
* `max_inflight_rows`: the maximum number of rows processed concurrently by the for-each operation.
222+
* `max_inflight_bytes`: the maximum number of concurrent inflight bytes for the for-each operation.
223+
Only the bytes the row holds before this for-each operation are taken into account.
224+
225+
For example:
226+
227+
```python
228+
@cocoindex.flow_def(name="DemoFlow")
229+
def demo_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):
230+
...
231+
with data_scope["table1"].row(max_inflight_rows=10, max_inflight_bytes=10*1024*1024) as table1_row:
232+
# Children operations
233+
table1_row["field2"] = table1_row["field1"].transform(DemoFunctionSpec(...))
234+
```
235+
207236
### Get a sub field
208237

209238
If the data slice has `Struct` type, you can obtain a data slice on a specific sub field of it, similar to getting a field of a data scope.

python/cocoindex/flow.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,18 +198,42 @@ def __getitem__(self, field_name: str) -> DataSlice[T]:
198198
raise KeyError(field_name)
199199
return DataSlice(_DataSliceState(self._state.flow_builder_state, field_slice))
200200

201-
def row(self) -> DataScope:
201+
def row(
202+
self,
203+
/,
204+
*,
205+
max_inflight_rows: int | None = None,
206+
max_inflight_bytes: int | None = None,
207+
) -> DataScope:
202208
"""
203209
Return a scope representing each row of the table.
204210
"""
205-
row_scope = self._state.engine_data_slice.table_row_scope()
211+
row_scope = self._state.flow_builder_state.engine_flow_builder.for_each(
212+
self._state.engine_data_slice,
213+
execution_options=dump_engine_object(
214+
_ExecutionOptions(
215+
max_inflight_rows=max_inflight_rows,
216+
max_inflight_bytes=max_inflight_bytes,
217+
),
218+
),
219+
)
206220
return DataScope(self._state.flow_builder_state, row_scope)
207221

208-
def for_each(self, f: Callable[[DataScope], None]) -> None:
222+
def for_each(
223+
self,
224+
f: Callable[[DataScope], None],
225+
/,
226+
*,
227+
max_inflight_rows: int | None = None,
228+
max_inflight_bytes: int | None = None,
229+
) -> None:
209230
"""
210231
Apply a function to each row of the collection.
211232
"""
212-
with self.row() as scope:
233+
with self.row(
234+
max_inflight_rows=max_inflight_rows,
235+
max_inflight_bytes=max_inflight_bytes,
236+
) as scope:
213237
f(scope)
214238

215239
def transform(

src/base/spec.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,10 @@ impl SpecFormatter for OpSpec {
255255

256256
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
257257
pub struct ExecutionOptions {
258+
#[serde(default, skip_serializing_if = "Option::is_none")]
258259
pub max_inflight_rows: Option<usize>,
260+
261+
#[serde(default, skip_serializing_if = "Option::is_none")]
259262
pub max_inflight_bytes: Option<usize>,
260263
}
261264

@@ -327,6 +330,9 @@ pub struct ForEachOpSpec {
327330
/// Mapping that provides a table to apply reactive operations to.
328331
pub field_path: FieldPath,
329332
pub op_scope: ReactiveOpScope,
333+
334+
#[serde(default)]
335+
pub execution_options: ExecutionOptions,
330336
}
331337

332338
impl ForEachOpSpec {

src/base/value.rs

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ use serde::{
1414
};
1515
use std::{collections::BTreeMap, ops::Deref, sync::Arc};
1616

17+
pub trait EstimatedByteSize: Sized {
18+
fn estimated_detached_byte_size(&self) -> usize;
19+
20+
fn estimated_byte_size(&self) -> usize {
21+
self.estimated_detached_byte_size() + std::mem::size_of::<Self>()
22+
}
23+
}
24+
1725
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
1826
pub struct RangeValue {
1927
pub start: usize,
@@ -855,7 +863,7 @@ impl<VS> Value<VS> {
855863
}
856864
}
857865

858-
impl Value<ScopeValue> {
866+
impl<VS: EstimatedByteSize> Value<VS> {
859867
pub fn estimated_byte_size(&self) -> usize {
860868
std::mem::size_of::<Self>()
861869
+ match self {
@@ -885,6 +893,16 @@ pub struct FieldValues<VS = ScopeValue> {
885893
pub fields: Vec<Value<VS>>,
886894
}
887895

896+
impl<VS: EstimatedByteSize> EstimatedByteSize for FieldValues<VS> {
897+
fn estimated_detached_byte_size(&self) -> usize {
898+
self.fields
899+
.iter()
900+
.map(Value::<VS>::estimated_byte_size)
901+
.sum::<usize>()
902+
+ self.fields.len() * std::mem::size_of::<Value<VS>>()
903+
}
904+
}
905+
888906
impl serde::Serialize for FieldValues {
889907
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
890908
self.fields.serialize(serializer)
@@ -954,23 +972,15 @@ where
954972
}
955973
}
956974

957-
impl FieldValues<ScopeValue> {
958-
fn estimated_detached_byte_size(&self) -> usize {
959-
self.fields
960-
.iter()
961-
.map(Value::estimated_byte_size)
962-
.sum::<usize>()
963-
+ self.fields.len() * std::mem::size_of::<Value<ScopeValue>>()
964-
}
975+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
976+
pub struct ScopeValue(pub FieldValues);
965977

966-
pub fn estimated_byte_size(&self) -> usize {
967-
self.estimated_detached_byte_size() + std::mem::size_of::<Self>()
978+
impl EstimatedByteSize for ScopeValue {
979+
fn estimated_detached_byte_size(&self) -> usize {
980+
self.0.estimated_detached_byte_size()
968981
}
969982
}
970983

971-
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
972-
pub struct ScopeValue(pub FieldValues);
973-
974984
impl Deref for ScopeValue {
975985
type Target = FieldValues;
976986

src/builder/analyzer.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,13 +807,19 @@ impl AnalyzerContext {
807807
analyzed_op_scope_fut
808808
};
809809
let op_name = reactive_op.name.clone();
810+
811+
let exec_options = foreach_op.execution_options.clone();
810812
async move {
811813
Ok(AnalyzedReactiveOp::ForEach(AnalyzedForEachOp {
812814
local_field_ref,
813815
op_scope: analyzed_op_scope_fut
814816
.await
815817
.with_context(|| format!("Analyzing foreach op: {op_name}"))?,
816818
name: op_name,
819+
concurrency_controller: concur_control::ConcurrencyController::new(
820+
exec_options.max_inflight_rows,
821+
exec_options.max_inflight_bytes,
822+
),
817823
}))
818824
}
819825
.boxed()

src/builder/flow_builder.rs

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -156,24 +156,6 @@ impl DataSlice {
156156
data_type: field_schema.value_type.clone().into(),
157157
}))
158158
}
159-
160-
pub fn table_row_scope(&self) -> PyResult<OpScopeRef> {
161-
let field_path = match self.value.as_ref() {
162-
spec::ValueMapping::Field(v) => &v.field_path,
163-
_ => return Err(PyException::new_err("expect field path")),
164-
};
165-
let num_parent_layers = self.scope.ancestors().count();
166-
let scope_name = format!(
167-
"{}_{}",
168-
field_path.last().map_or("", |s| s.as_str()),
169-
num_parent_layers
170-
);
171-
let (_, sub_op_scope) = self
172-
.scope
173-
.new_foreach_op_scope(scope_name, field_path)
174-
.into_py_result()?;
175-
Ok(OpScopeRef(sub_op_scope))
176-
}
177159
}
178160

179161
impl DataSlice {
@@ -383,6 +365,48 @@ impl FlowBuilder {
383365
Ok(())
384366
}
385367

368+
#[pyo3(signature = (data_slice, execution_options=None))]
369+
pub fn for_each(
370+
&mut self,
371+
data_slice: DataSlice,
372+
execution_options: Option<py::Pythonized<spec::ExecutionOptions>>,
373+
) -> PyResult<OpScopeRef> {
374+
let parent_scope = &data_slice.scope;
375+
let field_path = match data_slice.value.as_ref() {
376+
spec::ValueMapping::Field(v) => &v.field_path,
377+
_ => return Err(PyException::new_err("expect field path")),
378+
};
379+
let num_parent_layers = parent_scope.ancestors().count();
380+
let scope_name = format!(
381+
"{}_{}",
382+
field_path.last().map_or("", |s| s.as_str()),
383+
num_parent_layers
384+
);
385+
let (_, child_op_scope) = parent_scope
386+
.new_foreach_op_scope(scope_name.clone(), field_path)
387+
.into_py_result()?;
388+
389+
let reactive_op = spec::NamedSpec {
390+
name: format!(".for_each.{}", self.next_generated_op_id),
391+
spec: spec::ReactiveOpSpec::ForEach(spec::ForEachOpSpec {
392+
field_path: field_path.clone(),
393+
op_scope: spec::ReactiveOpScope {
394+
name: scope_name,
395+
ops: vec![],
396+
},
397+
execution_options: execution_options
398+
.map(|o| o.into_inner())
399+
.unwrap_or_default(),
400+
}),
401+
};
402+
self.next_generated_op_id += 1;
403+
self.get_mut_reactive_ops(parent_scope)
404+
.into_py_result()?
405+
.push(reactive_op);
406+
407+
Ok(OpScopeRef(child_op_scope))
408+
}
409+
386410
#[pyo3(signature = (kind, op_spec, args, target_scope, name))]
387411
pub fn transform(
388412
&mut self,
@@ -428,7 +452,9 @@ impl FlowBuilder {
428452
.into_py_result()?;
429453
std::mem::drop(analyzed);
430454

431-
self.get_mut_reactive_ops(op_scope).push(reactive_op);
455+
self.get_mut_reactive_ops(op_scope)
456+
.into_py_result()?
457+
.push(reactive_op);
432458

433459
let result = Self::last_field_to_data_slice(op_scope).into_py_result()?;
434460
Ok(result)
@@ -476,7 +502,9 @@ impl FlowBuilder {
476502
.into_py_result()?;
477503
std::mem::drop(analyzed);
478504

479-
self.get_mut_reactive_ops(common_scope).push(reactive_op);
505+
self.get_mut_reactive_ops(common_scope)
506+
.into_py_result()?
507+
.push(reactive_op);
480508

481509
let collector_schema = CollectorSchema::from_fields(
482510
fields
@@ -741,27 +769,19 @@ impl FlowBuilder {
741769
fn get_mut_reactive_ops<'a>(
742770
&'a mut self,
743771
op_scope: &OpScope,
744-
) -> &'a mut Vec<spec::NamedSpec<spec::ReactiveOpSpec>> {
745-
Self::get_mut_reactive_ops_internal(
746-
op_scope,
747-
&mut self.reactive_ops,
748-
&mut self.next_generated_op_id,
749-
)
772+
) -> Result<&'a mut Vec<spec::NamedSpec<spec::ReactiveOpSpec>>> {
773+
Self::get_mut_reactive_ops_internal(op_scope, &mut self.reactive_ops)
750774
}
751775

752776
fn get_mut_reactive_ops_internal<'a>(
753777
op_scope: &OpScope,
754778
root_reactive_ops: &'a mut Vec<spec::NamedSpec<spec::ReactiveOpSpec>>,
755-
next_generated_op_id: &mut usize,
756-
) -> &'a mut Vec<spec::NamedSpec<spec::ReactiveOpSpec>> {
757-
match &op_scope.parent {
779+
) -> Result<&'a mut Vec<spec::NamedSpec<spec::ReactiveOpSpec>>> {
780+
let result = match &op_scope.parent {
758781
None => root_reactive_ops,
759782
Some((parent_op_scope, field_path)) => {
760-
let parent_reactive_ops = Self::get_mut_reactive_ops_internal(
761-
parent_op_scope,
762-
root_reactive_ops,
763-
next_generated_op_id,
764-
);
783+
let parent_reactive_ops =
784+
Self::get_mut_reactive_ops_internal(parent_op_scope, root_reactive_ops)?;
765785
// Reuse the last foreach if it matches this scope; otherwise the scope has already been left, which is reported as an error.
766786
match parent_reactive_ops.last() {
767787
Some(spec::NamedSpec {
@@ -771,24 +791,15 @@ impl FlowBuilder {
771791
&& foreach_spec.op_scope.name == op_scope.name => {}
772792

773793
_ => {
774-
parent_reactive_ops.push(spec::NamedSpec {
775-
name: format!(".foreach.{}", next_generated_op_id),
776-
spec: spec::ReactiveOpSpec::ForEach(spec::ForEachOpSpec {
777-
field_path: field_path.clone(),
778-
op_scope: spec::ReactiveOpScope {
779-
name: op_scope.name.clone(),
780-
ops: vec![],
781-
},
782-
}),
783-
});
784-
*next_generated_op_id += 1;
794+
api_bail!("already out of op scope `{}`", op_scope.name);
785795
}
786796
}
787797
match &mut parent_reactive_ops.last_mut().unwrap().spec {
788798
spec::ReactiveOpSpec::ForEach(foreach_spec) => &mut foreach_spec.op_scope.ops,
789799
_ => unreachable!(),
790800
}
791801
}
792-
}
802+
};
803+
Ok(result)
793804
}
794805
}

src/builder/plan.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ pub struct AnalyzedForEachOp {
8282
pub name: String,
8383
pub local_field_ref: AnalyzedLocalFieldReference,
8484
pub op_scope: AnalyzedOpScope,
85+
pub concurrency_controller: concur_control::ConcurrencyController,
8586
}
8687

8788
pub struct AnalyzedCollectOp {

0 commit comments

Comments
 (0)