
Commit ca3b70e

Get tests passing
1 parent 2ed5bab commit ca3b70e

3 files changed: +104 -95 lines changed

docs/graph/beta/index.md

Lines changed: 3 additions & 2 deletions
@@ -123,7 +123,8 @@ Here's an example showcasing parallel execution with a map operation:
```python {title="parallel_processing.py"}
from dataclasses import dataclass

-from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_list_append


@dataclass
@@ -147,7 +148,7 @@ async def main():
        return ctx.inputs * ctx.inputs

    # Create a join to collect results
-    collect_results = g.join(ListAppendReducer[int])
+    collect_results = g.join(reduce_list_append, initial_factory=list[int])

    # Build the graph with map operation
    g.add(

docs/graph/beta/joins.md

Lines changed: 77 additions & 77 deletions
@@ -1,3 +1,5 @@
+from mypyc.primitives.float_ops import int_to_float_op
+
# Joins and Reducers

Join nodes synchronize and aggregate data from parallel execution paths. They use **Reducers** to combine multiple inputs into a single output.
@@ -12,12 +14,13 @@ When you use [parallel execution](parallel.md) (broadcasting or mapping), you of

## Creating Joins

-Create a join using [`g.join()`][pydantic_graph.beta.graph_builder.GraphBuilder.join] with a reducer type:
+Create a join using [`g.join()`][pydantic_graph.beta.graph_builder.GraphBuilder.join] with a reducer function and initial value or factory:

```python {title="basic_join.py"}
from dataclasses import dataclass

-from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_list_append


@dataclass
@@ -36,7 +39,7 @@ async def square(ctx: StepContext[SimpleState, None, int]) -> int:
        return ctx.inputs * ctx.inputs

    # Create a join to collect all squared values
-    collect = g.join(ListAppendReducer[int])
+    collect = g.join(reduce_list_append, initial_factory=list[int])

    g.add(
        g.edge_from(g.start_node).to(generate_numbers),
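Judging from the new `g.join(reduce_list_append, initial_factory=list[int])` call above, a reducer is now a plain function of the form `(current, inputs) -> accumulator` that the join folds over the parallel results, starting from the value produced by `initial` or `initial_factory`. A minimal sketch of that reading; `append_reducer` is a hypothetical stand-in, not the `reduce_list_append` imported above from `pydantic_graph.beta.join`:

```python
from functools import reduce


# Hypothetical stand-in for reduce_list_append, shown only to illustrate the
# presumed (current, inputs) -> accumulator shape of a reducer function.
def append_reducer(current: list[int], inputs: int) -> list[int]:
    current.append(inputs)
    return current


# g.join(append_reducer, initial_factory=list[int]) then behaves roughly like
# a fold over example outputs of the parallel `square` steps:
print(reduce(append_reducer, [1, 4, 9, 16], list[int]()))
#> [1, 4, 9, 16]
```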
@@ -66,7 +69,8 @@ Pydantic Graph provides several common reducer types out of the box:
```python {title="list_reducer.py"}
from dataclasses import dataclass

-from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_list_append


@dataclass
@@ -85,7 +89,7 @@ async def main():
    async def to_string(ctx: StepContext[SimpleState, None, int]) -> str:
        return f'value-{ctx.inputs}'

-    collect = g.join(ListAppendReducer[str])
+    collect = g.join(reduce_list_append, initial_factory=list[str])

    g.add(
        g.edge_from(g.start_node).to(generate),
@@ -109,7 +113,8 @@ _(This example is complete, it can be run "as is" — you'll need to add `import
```python {title="dict_reducer.py"}
from dataclasses import dataclass

-from pydantic_graph.beta import DictUpdateReducer, GraphBuilder, StepContext
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_dict_update


@dataclass
@@ -128,7 +133,7 @@ async def main():
    async def create_entry(ctx: StepContext[SimpleState, None, str]) -> dict[str, int]:
        return {ctx.inputs: len(ctx.inputs)}

-    merge = g.join(DictUpdateReducer[str, int])
+    merge = g.join(reduce_dict_update, initial_factory=dict[str, int])

    g.add(
        g.edge_from(g.start_node).to(generate_keys),
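For the dict case, the change to `g.join(reduce_dict_update, initial_factory=dict[str, int])` suggests the reducer merges each branch's dict into the accumulator, with later keys overwriting earlier ones. A small sketch of that assumption, with made-up keys; `merge_dicts` is a hypothetical stand-in, not the library function:

```python
# Hypothetical stand-in for reduce_dict_update: merge each parallel result
# dict into the accumulator created by initial_factory=dict[str, int].
def merge_dicts(current: dict[str, int], inputs: dict[str, int]) -> dict[str, int]:
    current.update(inputs)
    return current


merged: dict[str, int] = {}
for entry in [{'apple': 5}, {'banana': 6}, {'cherry': 6}]:
    merged = merge_dicts(merged, entry)
print(merged)
#> {'apple': 5, 'banana': 6, 'cherry': 6}
```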
@@ -153,7 +158,8 @@ _(This example is complete, it can be run "as is" — you'll need to add `import
```python {title="null_reducer.py"}
from dataclasses import dataclass

-from pydantic_graph.beta import GraphBuilder, NullReducer, StepContext
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_null


@dataclass
@@ -174,7 +180,7 @@ async def main():
        return ctx.inputs

    # We don't care about the outputs, only the side effect on state
-    ignore = g.join(NullReducer)
+    ignore = g.join(reduce_null, initial=None)

    @g.step
    async def get_total(ctx: StepContext[CounterState, None, None]) -> int:
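The switch to `g.join(reduce_null, initial=None)` implies a reducer that simply discards its inputs and keeps `None` as the accumulated value, so the join only acts as a synchronization point. A sketch of that assumption (`discard_all` is hypothetical, not the library's implementation):

```python
def discard_all(current: None, inputs: object) -> None:
    # Ignore every input; the join still waits for all parallel branches,
    # so downstream steps only run once each branch has updated the state.
    return None
```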
@@ -199,46 +205,30 @@ _(This example is complete, it can be run "as is" — you'll need to add `import

## Custom Reducers

-Create custom reducers by subclassing [`Reducer`][pydantic_graph.beta.join.Reducer]:
+Create custom reducers by defining a [`ReducerFunction`][pydantic_graph.beta.join.ReducerFunction]:

```python {title="custom_reducer.py"}
-from dataclasses import dataclass
-
-from pydantic_graph.beta import GraphBuilder, Reducer, StepContext
-
-
-@dataclass
-class SimpleState:
-    pass
-

-@dataclass(init=False)
-class SumReducer(Reducer[SimpleState, None, int, int]):
-    """Reducer that sums all input values."""
+from pydantic_graph.beta import GraphBuilder, StepContext

-    total: int = 0
-
-    def reduce(self, ctx: StepContext[SimpleState, None, int]) -> None:
-        """Called for each input - accumulate the sum."""
-        self.total += ctx.inputs

-    def finalize(self, ctx: StepContext[SimpleState, None, None]) -> int:
-        """Called after all inputs - return the final result."""
-        return self.total
+def reduce_sum(current: int, inputs: int) -> int:
+    """A reducer that sums numbers."""
+    return current + inputs


async def main():
-    g = GraphBuilder(state_type=SimpleState, output_type=int)
+    g = GraphBuilder(output_type=int)

    @g.step
-    async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]:
+    async def generate(ctx: StepContext[None, None, None]) -> list[int]:
        return [5, 10, 15, 20]

    @g.step
-    async def identity(ctx: StepContext[SimpleState, None, int]) -> int:
+    async def identity(ctx: StepContext[None, None, int]) -> int:
        return ctx.inputs

-    sum_join = g.join(SumReducer)
+    sum_join = g.join(reduce_sum, initial=0)

    g.add(
        g.edge_from(g.start_node).to(generate),
@@ -248,86 +238,95 @@ async def main():
    )

    graph = g.build()
-    result = await graph.run(state=SimpleState())
+    result = await graph.run()
    print(result)
    #> 50
```

_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

-### Reducer Lifecycle
-
-Reducers have two key methods:
-
-1. **`reduce(ctx)`** - Called for each input from parallel paths. Use this to accumulate data.
-2. **`finalize(ctx)`** - Called once after all inputs are received. Return the final aggregated value.
-
## Reducers with State Access

Reducers can access and modify the graph state:

```python {title="stateful_reducer.py"}
from dataclasses import dataclass

-from pydantic_graph.beta import GraphBuilder, Reducer, StepContext
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import ReducerContext


@dataclass
class MetricsState:
-    items_processed: int = 0
-    sum_total: int = 0
-
+    total_count: int = 0
+    total_sum: int = 0

-@dataclass(init=False)
-class MetricsReducer(Reducer[MetricsState, None, int, dict[str, int]]):
-    """Reducer that tracks processing metrics in state."""

+@dataclass
+class ReducedMetrics:
    count: int = 0
-    total: int = 0
+    sum: int = 0
+

-    def reduce(self, ctx: StepContext[MetricsState, None, int]) -> None:
-        self.count += 1
-        self.total += ctx.inputs
-        ctx.state.items_processed += 1
-        ctx.state.sum_total += ctx.inputs
+def reduce_metrics_sum(ctx: ReducerContext[MetricsState, None], current: ReducedMetrics, inputs: int) -> ReducedMetrics:
+    ctx.state.total_count += 1
+    ctx.state.total_sum += inputs
+    return ReducedMetrics(count=current.count + 1, sum=current.sum + inputs)

-    def finalize(self, ctx: StepContext[MetricsState, None, None]) -> dict[str, int]:
-        return {
-            'count': self.count,
-            'total': self.total,
-        }
+def reduce_metrics_max(current: ReducedMetrics, inputs: ReducedMetrics) -> ReducedMetrics:
+    return ReducedMetrics(count=max(current.count, inputs.count), sum=max(current.sum, inputs.sum))


async def main():
    g = GraphBuilder(state_type=MetricsState, output_type=dict[str, int])

    @g.step
-    async def generate(ctx: StepContext[MetricsState, None, None]) -> list[int]:
-        return [10, 20, 30, 40]
+    async def generate(ctx: StepContext[object, None, None]) -> list[int]:
+        return [1, 3, 5, 7, 9, 10, 20, 30, 40]

    @g.step
-    async def process(ctx: StepContext[MetricsState, None, int]) -> int:
+    async def process_even(ctx: StepContext[MetricsState, None, int]) -> int:
        return ctx.inputs * 2

-    metrics = g.join(MetricsReducer)
+    @g.step
+    async def process_odd(ctx: StepContext[MetricsState, None, int]) -> int:
+        return ctx.inputs * 3
+
+    metrics_even = g.join(reduce_metrics_sum, initial_factory=ReducedMetrics, node_id='metrics_even')
+    metrics_odd = g.join(reduce_metrics_sum, initial_factory=ReducedMetrics, node_id='metrics_odd')
+    metrics_max = g.join(reduce_metrics_max, initial_factory=ReducedMetrics, node_id='metrics_max')

    g.add(
        g.edge_from(g.start_node).to(generate),
-        g.edge_from(generate).map().to(process),
-        g.edge_from(process).to(metrics),
-        g.edge_from(metrics).to(g.end_node),
+        # Send even and odd numbers to their respective `process` steps
+        g.edge_from(generate).map().to(
+            g.decision()
+            .branch(g.match(int, matches=lambda x: x % 2 == 0).label('even').to(process_even))
+            .branch(g.match(int, matches=lambda x: x % 2 == 1).label('odd').to(process_odd))
+        ),
+        # Reduce metrics for even and odd numbers separately
+        g.edge_from(process_even).to(metrics_even),
+        g.edge_from(process_odd).to(metrics_odd),
+        # Aggregate the max values for each field
+        g.edge_from(metrics_even).to(metrics_max),
+        g.edge_from(metrics_odd).to(metrics_max),
+        # Finish the graph run with the final reduced value
+        g.edge_from(metrics_max).to(g.end_node),
    )

    graph = g.build()
    state = MetricsState()
    result = await graph.run(state=state)

    print(f'Result: {result}')
-    #> Result: {'count': 4, 'total': 200}
-    print(f'State items_processed: {state.items_processed}')
-    #> State items_processed: 4
-    print(f'State sum_total: {state.sum_total}')
-    #> State sum_total: 200
+    #> Result: ReducedMetrics(count=5, sum=200)
+    # > Result: {'count': 4, 'total': 200}
+    print(f'State total_count: {state.total_count}')
+    #> State total_count: 9
+    # > State items_processed: 4
+    print(f'State total_sum: {state.total_sum}')
+    #> State total_sum: 275
+    # > State sum_total: 200
```

_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_
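The rewritten example above suggests two reducer shapes: a pure function `(current, inputs) -> accumulator` (like `reduce_metrics_max`) and a context-aware one that takes a `ReducerContext` first and may mutate graph state (like `reduce_metrics_sum`). A self-contained sketch of the even-numbers fold under that reading, using a hypothetical `DemoReducerContext` in place of `pydantic_graph.beta.join.ReducerContext`:

```python
from dataclasses import dataclass


@dataclass
class MetricsState:
    total_count: int = 0
    total_sum: int = 0


@dataclass
class ReducedMetrics:
    count: int = 0
    sum: int = 0


# Hypothetical stand-in for ReducerContext, used only to show the call shape.
@dataclass
class DemoReducerContext:
    state: MetricsState


def reduce_metrics_sum(ctx: DemoReducerContext, current: ReducedMetrics, inputs: int) -> ReducedMetrics:
    # Context-aware reducer: update shared state, return a new accumulator.
    ctx.state.total_count += 1
    ctx.state.total_sum += inputs
    return ReducedMetrics(count=current.count + 1, sum=current.sum + inputs)


# Roughly what the `metrics_even` join computes for the even inputs after
# `process_even` has doubled them:
state = MetricsState()
ctx = DemoReducerContext(state=state)
acc = ReducedMetrics()
for value in (x * 2 for x in [10, 20, 30, 40]):
    acc = reduce_metrics_sum(ctx, acc, value)
print(acc, state)
#> ReducedMetrics(count=4, sum=200) MetricsState(total_count=4, total_sum=200)
```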
@@ -339,7 +338,8 @@ A graph can have multiple independent joins:
```python {title="multiple_joins.py"}
from dataclasses import dataclass, field

-from pydantic_graph.beta import GraphBuilder, ListAppendReducer, StepContext
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_list_append


@dataclass
@@ -366,8 +366,8 @@ async def main():
    async def process_b(ctx: StepContext[MultiState, None, int]) -> int:
        return ctx.inputs * 3

-    join_a = g.join(ListAppendReducer[int], node_id='join_a')
-    join_b = g.join(ListAppendReducer[int], node_id='join_b')
+    join_a = g.join(reduce_list_append, initial_factory=list[int], node_id='join_a')
+    join_b = g.join(reduce_list_append, initial_factory=list[int], node_id='join_b')

    @g.step
    async def store_a(ctx: StepContext[MultiState, None, list[int]]) -> None:
@@ -412,11 +412,11 @@ _(This example is complete, it can be run "as is" — you'll need to add `import
Like steps, joins can have custom IDs:

```python {title="join_custom_id.py" requires="basic_join.py"}
-from pydantic_graph.beta import ListAppendReducer
+from pydantic_graph.beta.join import reduce_list_append

from basic_join import g

-my_join = g.join(ListAppendReducer[int], node_id='my_custom_join_id')
+my_join = g.join(reduce_list_append, initial_factory=list[int], node_id='my_custom_join_id')
```

## How Joins Work
