1+ from mypyc.primitives.float_ops import int_to_float_op
2+
13# Joins and Reducers
24
35Join nodes synchronize and aggregate data from parallel execution paths. They use ** Reducers** to combine multiple inputs into a single output.
@@ -12,12 +14,13 @@ When you use [parallel execution](parallel.md) (broadcasting or mapping), you of
1214
1315## Creating Joins
1416
15- Create a join using [ ` g.join() ` ] [ pydantic_graph.beta.graph_builder.GraphBuilder.join ] with a reducer type :
17+ Create a join using [ ` g.join() ` ] [ pydantic_graph.beta.graph_builder.GraphBuilder.join ] with a reducer function and initial value or factory :
1618
1719``` python {title="basic_join.py"}
1820from dataclasses import dataclass
1921
20- from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext
22+ from pydantic_graph.beta import GraphBuilder, StepContext
23+ from pydantic_graph.beta.join import reduce_list_append
2124
2225
2326@dataclass
@@ -36,7 +39,7 @@ async def square(ctx: StepContext[SimpleState, None, int]) -> int:
3639 return ctx.inputs * ctx.inputs
3740
3841# Create a join to collect all squared values
39- collect = g.join(ListAppendReducer [int ])
42+ collect = g.join(reduce_list_append, initial_factory = list [int ])
4043
4144g.add(
4245 g.edge_from(g.start_node).to(generate_numbers),
@@ -66,7 +69,8 @@ Pydantic Graph provides several common reducer types out of the box:
6669``` python {title="list_reducer.py"}
6770from dataclasses import dataclass
6871
69- from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext
72+ from pydantic_graph.beta import GraphBuilder, StepContext
73+ from pydantic_graph.beta.join import reduce_list_append
7074
7175
7276@dataclass
@@ -85,7 +89,7 @@ async def main():
8589 async def to_string (ctx : StepContext[SimpleState, None , int ]) -> str :
8690 return f ' value- { ctx.inputs} '
8791
88- collect = g.join(ListAppendReducer [str ])
92+ collect = g.join(reduce_list_append, initial_factory = list [str ])
8993
9094 g.add(
9195 g.edge_from(g.start_node).to(generate),
@@ -109,7 +113,8 @@ _(This example is complete, it can be run "as is" — you'll need to add `import
109113``` python {title="dict_reducer.py"}
110114from dataclasses import dataclass
111115
112- from pydantic_graph.beta import DictUpdateReducer, GraphBuilder, StepContext
116+ from pydantic_graph.beta import GraphBuilder, StepContext
117+ from pydantic_graph.beta.join import reduce_dict_update
113118
114119
115120@dataclass
@@ -128,7 +133,7 @@ async def main():
128133 async def create_entry (ctx : StepContext[SimpleState, None , str ]) -> dict[str , int ]:
129134 return {ctx.inputs: len (ctx.inputs)}
130135
131- merge = g.join(DictUpdateReducer [str , int ])
136+ merge = g.join(reduce_dict_update, initial_factory = dict [str , int ])
132137
133138 g.add(
134139 g.edge_from(g.start_node).to(generate_keys),
@@ -153,7 +158,8 @@ _(This example is complete, it can be run "as is" — you'll need to add `import
153158``` python {title="null_reducer.py"}
154159from dataclasses import dataclass
155160
156- from pydantic_graph.beta import GraphBuilder, NullReducer, StepContext
161+ from pydantic_graph.beta import GraphBuilder, StepContext
162+ from pydantic_graph.beta.join import reduce_null
157163
158164
159165@dataclass
@@ -174,7 +180,7 @@ async def main():
174180 return ctx.inputs
175181
176182 # We don't care about the outputs, only the side effect on state
177- ignore = g.join(NullReducer )
183+ ignore = g.join(reduce_null, initial = None )
178184
179185 @g.step
180186 async def get_total (ctx : StepContext[CounterState, None , None ]) -> int :
@@ -199,46 +205,30 @@ _(This example is complete, it can be run "as is" — you'll need to add `import
199205
200206## Custom Reducers
201207
202- Create custom reducers by subclassing [ ` Reducer ` ] [ pydantic_graph.beta.join.Reducer ] :
208+ Create custom reducers by defining a [ ` ReducerFunction ` ] [ pydantic_graph.beta.join.ReducerFunction ] :
203209
204210``` python {title="custom_reducer.py"}
205- from dataclasses import dataclass
206-
207- from pydantic_graph.beta import GraphBuilder, Reducer, StepContext
208-
209-
210- @dataclass
211- class SimpleState :
212- pass
213-
214211
215- @dataclass (init = False )
216- class SumReducer (Reducer[SimpleState, None , int , int ]):
217- """ Reducer that sums all input values."""
212+ from pydantic_graph.beta import GraphBuilder, StepContext
218213
219- total: int = 0
220-
221- def reduce (self , ctx : StepContext[SimpleState, None , int ]) -> None :
222- """ Called for each input - accumulate the sum."""
223- self .total += ctx.inputs
224214
225- def finalize ( self , ctx : StepContext[SimpleState, None , None ] ) -> int :
226- """ Called after all inputs - return the final result ."""
227- return self .total
215+ def reduce_sum ( current : int , inputs : int ) -> int :
216+ """ A reducer that sums numbers ."""
217+ return current + inputs
228218
229219
230220async def main ():
231- g = GraphBuilder(state_type = SimpleState, output_type = int )
221+ g = GraphBuilder(output_type = int )
232222
233223 @g.step
234- async def generate (ctx : StepContext[SimpleState , None , None ]) -> list[int ]:
224+ async def generate (ctx : StepContext[None , None , None ]) -> list[int ]:
235225 return [5 , 10 , 15 , 20 ]
236226
237227 @g.step
238- async def identity (ctx : StepContext[SimpleState , None , int ]) -> int :
228+ async def identity (ctx : StepContext[None , None , int ]) -> int :
239229 return ctx.inputs
240230
241- sum_join = g.join(SumReducer )
231+ sum_join = g.join(reduce_sum, initial = 0 )
242232
243233 g.add(
244234 g.edge_from(g.start_node).to(generate),
@@ -248,86 +238,95 @@ async def main():
248238 )
249239
250240 graph = g.build()
251- result = await graph.run(state = SimpleState() )
241+ result = await graph.run()
252242 print (result)
253243 # > 50
254244```
255245
256246_ (This example is complete, it can be run "as is" — you'll need to add ` import asyncio; asyncio.run(main()) ` to run ` main ` )_
257247
258- ### Reducer Lifecycle
259-
260- Reducers have two key methods:
261-
262- 1 . ** ` reduce(ctx) ` ** - Called for each input from parallel paths. Use this to accumulate data.
263- 2 . ** ` finalize(ctx) ` ** - Called once after all inputs are received. Return the final aggregated value.
264-
265248## Reducers with State Access
266249
267250Reducers can access and modify the graph state:
268251
269252``` python {title="stateful_reducer.py"}
270253from dataclasses import dataclass
271254
272- from pydantic_graph.beta import GraphBuilder, Reducer, StepContext
255+ from pydantic_graph.beta import GraphBuilder, StepContext
256+ from pydantic_graph.beta.join import ReducerContext
273257
274258
275259@dataclass
276260class MetricsState :
277- items_processed: int = 0
278- sum_total: int = 0
279-
261+ total_count: int = 0
262+ total_sum: int = 0
280263
281- @dataclass (init = False )
282- class MetricsReducer (Reducer[MetricsState, None , int , dict[str , int ]]):
283- """ Reducer that tracks processing metrics in state."""
284264
265+ @dataclass
266+ class ReducedMetrics :
285267 count: int = 0
286- total: int = 0
268+ sum : int = 0
269+
287270
288- def reduce (self , ctx : StepContext[MetricsState, None , int ]) -> None :
289- self .count += 1
290- self .total += ctx.inputs
291- ctx.state.items_processed += 1
292- ctx.state.sum_total += ctx.inputs
271+ def reduce_metrics_sum (ctx : ReducerContext[MetricsState, None ], current : ReducedMetrics, inputs : int ) -> ReducedMetrics:
272+ ctx.state.total_count += 1
273+ ctx.state.total_sum += inputs
274+ return ReducedMetrics(count = current.count + 1 , sum = current.sum + inputs)
293275
294- def finalize (self , ctx : StepContext[MetricsState, None , None ]) -> dict[str , int ]:
295- return {
296- ' count' : self .count,
297- ' total' : self .total,
298- }
276+ def reduce_metrics_max (current : ReducedMetrics, inputs : ReducedMetrics) -> ReducedMetrics:
277+ return ReducedMetrics(count = max (current.count, inputs.count), sum = max (current.sum, inputs.sum))
299278
300279
301280async def main ():
302281 g = GraphBuilder(state_type = MetricsState, output_type = dict[str , int ])
303282
304283 @g.step
305- async def generate (ctx : StepContext[MetricsState , None , None ]) -> list[int ]:
306- return [10 , 20 , 30 , 40 ]
284+ async def generate (ctx : StepContext[object , None , None ]) -> list[int ]:
285+ return [1 , 3 , 5 , 7 , 9 , 10 , 20 , 30 , 40 ]
307286
308287 @g.step
309- async def process (ctx : StepContext[MetricsState, None , int ]) -> int :
288+ async def process_even (ctx : StepContext[MetricsState, None , int ]) -> int :
310289 return ctx.inputs * 2
311290
312- metrics = g.join(MetricsReducer)
291+ @g.step
292+ async def process_odd (ctx : StepContext[MetricsState, None , int ]) -> int :
293+ return ctx.inputs * 3
294+
295+ metrics_even = g.join(reduce_metrics_sum, initial_factory = ReducedMetrics, node_id = ' metrics_even' )
296+ metrics_odd = g.join(reduce_metrics_sum, initial_factory = ReducedMetrics, node_id = ' metrics_odd' )
297+ metrics_max = g.join(reduce_metrics_max, initial_factory = ReducedMetrics, node_id = ' metrics_max' )
313298
314299 g.add(
315300 g.edge_from(g.start_node).to(generate),
316- g.edge_from(generate).map().to(process),
317- g.edge_from(process).to(metrics),
318- g.edge_from(metrics).to(g.end_node),
301+ # Send even and odd numbers to their respective `process` steps
302+ g.edge_from(generate).map().to(
303+ g.decision()
304+ .branch(g.match(int , matches = lambda x : x % 2 == 0 ).label(' even' ).to(process_even))
305+ .branch(g.match(int , matches = lambda x : x % 2 == 1 ).label(' odd' ).to(process_odd))
306+ ),
307+ # Reduce metrics for even and odd numbers separately
308+ g.edge_from(process_even).to(metrics_even),
309+ g.edge_from(process_odd).to(metrics_odd),
310+ # Aggregate the max values for each field
311+ g.edge_from(metrics_even).to(metrics_max),
312+ g.edge_from(metrics_odd).to(metrics_max),
313+ # Finish the graph run with the final reduced value
314+ g.edge_from(metrics_max).to(g.end_node),
319315 )
320316
321317 graph = g.build()
322318 state = MetricsState()
323319 result = await graph.run(state = state)
324320
325321 print (f ' Result: { result} ' )
326- # > Result: {'count': 4, 'total': 200}
327- print (f ' State items_processed: { state.items_processed} ' )
328- # > State items_processed: 4
329- print (f ' State sum_total: { state.sum_total} ' )
330- # > State sum_total: 200
322+ # > Result: ReducedMetrics(count=5, sum=200)
323+ # > Result: {'count': 4, 'total': 200}
324+ print (f ' State total_count: { state.total_count} ' )
325+ # > State total_count: 9
326+ # > State items_processed: 4
327+ print (f ' State total_sum: { state.total_sum} ' )
328+ # > State total_sum: 275
329+ # > State sum_total: 200
331330```
332331
333332_ (This example is complete, it can be run "as is" — you'll need to add ` import asyncio; asyncio.run(main()) ` to run ` main ` )_
@@ -339,7 +338,8 @@ A graph can have multiple independent joins:
339338``` python {title="multiple_joins.py"}
340339from dataclasses import dataclass, field
341340
342- from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext
341+ from pydantic_graph.beta import GraphBuilder, StepContext
342+ from pydantic_graph.beta.join import reduce_list_append
343343
344344
345345@dataclass
@@ -366,8 +366,8 @@ async def main():
366366 async def process_b (ctx : StepContext[MultiState, None , int ]) -> int :
367367 return ctx.inputs * 3
368368
369- join_a = g.join(ListAppendReducer [int ], node_id = ' join_a' )
370- join_b = g.join(ListAppendReducer [int ], node_id = ' join_b' )
369+ join_a = g.join(reduce_list_append, initial_factory = list [int ], node_id = ' join_a' )
370+ join_b = g.join(reduce_list_append, initial_factory = list [int ], node_id = ' join_b' )
371371
372372 @g.step
373373 async def store_a (ctx : StepContext[MultiState, None , list[int ]]) -> None :
@@ -412,11 +412,11 @@ _(This example is complete, it can be run "as is" — you'll need to add `import
412412Like steps, joins can have custom IDs:
413413
414414``` python {title="join_custom_id.py" requires="basic_join.py"}
415- from pydantic_graph.beta import ListAppendReducer
415+ from pydantic_graph.beta.join import reduce_list_append
416416
417417from basic_join import g
418418
419- my_join = g.join(ListAppendReducer [int ], node_id = ' my_custom_join_id' )
419+ my_join = g.join(reduce_list_append, initial_factory = list [int ], node_id = ' my_custom_join_id' )
420420```
421421
422422## How Joins Work
0 commit comments