|
1 | | -from mypyc.primitives.float_ops import int_to_float_op |
2 | | - |
3 | 1 | # Joins and Reducers |
4 | 2 |
|
5 | 3 | Join nodes synchronize and aggregate data from parallel execution paths. They use **Reducers** to combine multiple inputs into a single output. |
@@ -106,6 +104,51 @@ async def main(): |
106 | 104 |
|
107 | 105 | _(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ |
108 | 106 |
|
| 107 | +### `reduce_list_extend` |
| 108 | + |
| 109 | +[`reduce_list_extend`][pydantic_graph.beta.join.reduce_list_extend] extends a list with an iterable of items: |
| 110 | + |
| 111 | +```python {title="list_extend_reducer.py"} |
| 112 | +from dataclasses import dataclass |
| 113 | + |
| 114 | +from pydantic_graph.beta import GraphBuilder, StepContext |
| 115 | +from pydantic_graph.beta.join import reduce_list_extend |
| 116 | + |
| 117 | + |
| 118 | +@dataclass |
| 119 | +class SimpleState: |
| 120 | + pass |
| 121 | + |
| 122 | + |
| 123 | +async def main(): |
| 124 | + g = GraphBuilder(state_type=SimpleState, output_type=list[int]) |
| 125 | + |
| 126 | + @g.step |
| 127 | + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: |
| 128 | + return [1, 2, 3] |
| 129 | + |
| 130 | + @g.step |
| 131 | + async def create_range(ctx: StepContext[SimpleState, None, int]) -> list[int]: |
| 132 | + """Create a range from 0 to the input value.""" |
| 133 | + return list(range(ctx.inputs)) |
| 134 | + |
| 135 | + collect = g.join(reduce_list_extend, initial_factory=list[int]) |
| 136 | + |
| 137 | + g.add( |
| 138 | + g.edge_from(g.start_node).to(generate), |
| 139 | + g.edge_from(generate).map().to(create_range), |
| 140 | + g.edge_from(create_range).to(collect), |
| 141 | + g.edge_from(collect).to(g.end_node), |
| 142 | + ) |
| 143 | + |
| 144 | + graph = g.build() |
| 145 | + result = await graph.run(state=SimpleState()) |
| 146 | + print(sorted(result)) |
| 147 | + #> [0, 0, 0, 1, 1, 2] |
| 148 | +``` |
| 149 | + |
| 150 | +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ |
| 151 | + |
109 | 152 | ### `reduce_dict_update` |
110 | 153 |
|
111 | 154 | [`reduce_dict_update`][pydantic_graph.beta.join.reduce_dict_update] merges dictionaries together: |
@@ -203,6 +246,104 @@ async def main(): |
203 | 246 |
|
204 | 247 | _(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ |
205 | 248 |
|
| 249 | +### `reduce_sum` |
| 250 | + |
| 251 | +[`reduce_sum`][pydantic_graph.beta.join.reduce_sum] sums numeric values: |
| 252 | + |
| 253 | +```python {title="sum_reducer.py"} |
| 254 | +from dataclasses import dataclass |
| 255 | + |
| 256 | +from pydantic_graph.beta import GraphBuilder, StepContext |
| 257 | +from pydantic_graph.beta.join import reduce_sum |
| 258 | + |
| 259 | + |
| 260 | +@dataclass |
| 261 | +class SimpleState: |
| 262 | + pass |
| 263 | + |
| 264 | + |
| 265 | +async def main(): |
| 266 | + g = GraphBuilder(state_type=SimpleState, output_type=int) |
| 267 | + |
| 268 | + @g.step |
| 269 | + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: |
| 270 | + return [10, 20, 30, 40] |
| 271 | + |
| 272 | + @g.step |
| 273 | + async def identity(ctx: StepContext[SimpleState, None, int]) -> int: |
| 274 | + return ctx.inputs |
| 275 | + |
| 276 | + sum_join = g.join(reduce_sum, initial=0) |
| 277 | + |
| 278 | + g.add( |
| 279 | + g.edge_from(g.start_node).to(generate), |
| 280 | + g.edge_from(generate).map().to(identity), |
| 281 | + g.edge_from(identity).to(sum_join), |
| 282 | + g.edge_from(sum_join).to(g.end_node), |
| 283 | + ) |
| 284 | + |
| 285 | + graph = g.build() |
| 286 | + result = await graph.run(state=SimpleState()) |
| 287 | + print(result) |
| 288 | + #> 100 |
| 289 | +``` |
| 290 | + |
| 291 | +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ |
| 292 | + |
| 293 | +### `ReduceFirstValue` |
| 294 | + |
| 295 | +[`ReduceFirstValue`][pydantic_graph.beta.join.ReduceFirstValue] returns the first value it receives and cancels all other parallel tasks. This is useful for "race" scenarios where you want the first successful result: |
| 296 | + |
| 297 | +```python {title="first_value_reducer.py"} |
| 298 | +import asyncio |
| 299 | +from dataclasses import dataclass |
| 300 | + |
| 301 | +from pydantic_graph.beta import GraphBuilder, StepContext |
| 302 | +from pydantic_graph.beta.join import ReduceFirstValue |
| 303 | + |
| 304 | + |
| 305 | +@dataclass |
| 306 | +class SimpleState: |
| 307 | + tasks_completed: int = 0 |
| 308 | + |
| 309 | + |
| 310 | +async def main(): |
| 311 | + g = GraphBuilder(state_type=SimpleState, output_type=str) |
| 312 | + |
| 313 | + @g.step |
| 314 | + async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: |
| 315 | + return [1, 2, 3, 4, 5] |
| 316 | + |
| 317 | + @g.step |
| 318 | + async def slow_process(ctx: StepContext[SimpleState, None, int]) -> str: |
| 319 | + """Simulate variable processing times.""" |
| 320 | + # Simulate different delays |
| 321 | + await asyncio.sleep(ctx.inputs * 0.1) |
| 322 | + ctx.state.tasks_completed += 1 |
| 323 | + return f'Result from task {ctx.inputs}' |
| 324 | + |
| 325 | + # Use ReduceFirstValue to get the first result and cancel the rest |
| 326 | + first_result = g.join(ReduceFirstValue[str](), initial=None, node_id='first_result') |
| 327 | + |
| 328 | + g.add( |
| 329 | + g.edge_from(g.start_node).to(generate), |
| 330 | + g.edge_from(generate).map().to(slow_process), |
| 331 | + g.edge_from(slow_process).to(first_result), |
| 332 | + g.edge_from(first_result).to(g.end_node), |
| 333 | + ) |
| 334 | + |
| 335 | + graph = g.build() |
| 336 | + state = SimpleState() |
| 337 | + result = await graph.run(state=state) |
| 338 | + |
| 339 | + print(result) |
| 340 | + #> Result from task 1 |
| 341 | + print(f'Tasks completed: {state.tasks_completed}') |
| 342 | + #> Tasks completed: 1 |
| 343 | +``` |
| 344 | + |
| 345 | +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ |
| 346 | + |
206 | 347 | ## Custom Reducers |
207 | 348 |
|
208 | 349 | Create custom reducers by defining a [`ReducerFunction`][pydantic_graph.beta.join.ReducerFunction]: |
@@ -331,6 +472,72 @@ async def main(): |
331 | 472 |
|
332 | 473 | _(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ |
333 | 474 |
|
| 475 | +### Canceling Sibling Tasks |
| 476 | + |
| 477 | +Reducers with access to [`ReducerContext`][pydantic_graph.beta.join.ReducerContext] can call [`ctx.cancel_sibling_tasks()`][pydantic_graph.beta.join.ReducerContext.cancel_sibling_tasks] to cancel all other parallel tasks in the same fork. This is useful for early termination when you've found what you need: |
| 478 | + |
| 479 | +```python {title="cancel_siblings.py"} |
| 480 | +import asyncio |
| 481 | +from dataclasses import dataclass |
| 482 | + |
| 483 | +from pydantic_graph.beta import GraphBuilder, StepContext |
| 484 | +from pydantic_graph.beta.join import ReducerContext |
| 485 | + |
| 486 | + |
| 487 | +@dataclass |
| 488 | +class SearchState: |
| 489 | + searches_completed: int = 0 |
| 490 | + |
| 491 | + |
| 492 | +def reduce_find_match(ctx: ReducerContext[SearchState, None], current: str | None, inputs: str) -> str | None: |
| 493 | + """Return the first input that contains 'target' and cancel remaining tasks.""" |
| 494 | + if current is not None: |
| 495 | + # We already found a match, ignore subsequent inputs |
| 496 | + return current |
| 497 | + if 'target' in inputs: |
| 498 | + # Found a match! Cancel all other parallel tasks |
| 499 | + ctx.cancel_sibling_tasks() |
| 500 | + return inputs |
| 501 | + return None |
| 502 | + |
| 503 | + |
| 504 | +async def main(): |
| 505 | + g = GraphBuilder(state_type=SearchState, output_type=str | None) |
| 506 | + |
| 507 | + @g.step |
| 508 | + async def generate_searches(ctx: StepContext[SearchState, None, None]) -> list[str]: |
| 509 | + return ['item1', 'item2', 'target_item', 'item4', 'item5'] |
| 510 | + |
| 511 | + @g.step |
| 512 | + async def search(ctx: StepContext[SearchState, None, str]) -> str: |
| 513 | + """Simulate a slow search operation.""" |
| 514 | + await asyncio.sleep(0.1) |
| 515 | + ctx.state.searches_completed += 1 |
| 516 | + return ctx.inputs |
| 517 | + |
| 518 | + find_match = g.join(reduce_find_match, initial=None) |
| 519 | + |
| 520 | + g.add( |
| 521 | + g.edge_from(g.start_node).to(generate_searches), |
| 522 | + g.edge_from(generate_searches).map().to(search), |
| 523 | + g.edge_from(search).to(find_match), |
| 524 | + g.edge_from(find_match).to(g.end_node), |
| 525 | + ) |
| 526 | + |
| 527 | + graph = g.build() |
| 528 | + state = SearchState() |
| 529 | + result = await graph.run(state=state) |
| 530 | + |
| 531 | + print(f'Found: {result}') |
| 532 | + #> Found: target_item |
| 533 | + print(f'Searches completed: {state.searches_completed}') |
| 534 | + #> Searches completed: 3 |
| 535 | +``` |
| 536 | + |
| 537 | +_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_ |
| 538 | + |
| 539 | +Note that only 3 searches completed instead of all 5, because the reducer canceled the remaining tasks after finding a match. |
| 540 | + |
334 | 541 | ## Multiple Joins |
335 | 542 |
|
336 | 543 | A graph can have multiple independent joins: |
|
0 commit comments