Skip to content

Commit d38464e

Browse files
committed
Fix failing docs tests
1 parent 9fc966f commit d38464e

File tree

2 files changed

+48
-30
lines changed

2 files changed

+48
-30
lines changed

docs/graph/beta/joins.md

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,29 @@ class SimpleState:
2525
pass
2626

2727

28-
async def main():
29-
g = GraphBuilder(state_type=SimpleState, output_type=list[int])
28+
g = GraphBuilder(state_type=SimpleState, output_type=list[int])
3029

31-
@g.step
32-
async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]:
33-
return [1, 2, 3, 4, 5]
30+
@g.step
31+
async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]:
32+
return [1, 2, 3, 4, 5]
3433

35-
@g.step
36-
async def square(ctx: StepContext[SimpleState, None, int]) -> int:
37-
return ctx.inputs * ctx.inputs
34+
@g.step
35+
async def square(ctx: StepContext[SimpleState, None, int]) -> int:
36+
return ctx.inputs * ctx.inputs
3837

39-
# Create a join to collect all squared values
40-
collect = g.join(ListReducer[int])
38+
# Create a join to collect all squared values
39+
collect = g.join(ListReducer[int])
4140

42-
g.add(
43-
g.edge_from(g.start_node).to(generate_numbers),
44-
g.edge_from(generate_numbers).map().to(square),
45-
g.edge_from(square).to(collect),
46-
g.edge_from(collect).to(g.end_node),
47-
)
41+
g.add(
42+
g.edge_from(g.start_node).to(generate_numbers),
43+
g.edge_from(generate_numbers).map().to(square),
44+
g.edge_from(square).to(collect),
45+
g.edge_from(collect).to(g.end_node),
46+
)
4847

49-
graph = g.build()
48+
graph = g.build()
49+
50+
async def main():
5051
result = await graph.run(state=SimpleState())
5152
print(sorted(result))
5253
#> [1, 4, 9, 16, 25]
@@ -139,7 +140,7 @@ async def main():
139140
graph = g.build()
140141
result = await graph.run(state=SimpleState())
141142
print(result)
142-
#> {'apple': 5, 'banana': 6, 'cherry': 6}
143+
#> {'cherry': 6, 'banana': 6, 'apple': 5}
143144
```
144145

145146
_(This example is complete, it can be run "as is" — you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_
@@ -410,7 +411,7 @@ _(This example is complete, it can be run "as is" — you'll need to add `import
410411
Like steps, joins can have custom IDs:
411412

412413
```python {title="join_custom_id.py" requires="basic_join.py"}
413-
from basic_join import g, ListReducer
414+
from basic_join import ListReducer, g
414415

415416
my_join = g.join(ListReducer[int], node_id='my_custom_join_id')
416417
```

docs/graph/beta/steps.md

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@ from pydantic_graph.beta import GraphBuilder, StepContext
1616
class MyState:
1717
counter: int = 0
1818

19+
g = GraphBuilder(state_type=MyState, output_type=int)
1920

20-
async def main():
21-
g = GraphBuilder(state_type=MyState, output_type=int)
21+
@g.step
22+
async def increment(ctx: StepContext[MyState, None, None]) -> int:
23+
ctx.state.counter += 1
24+
return ctx.state.counter
2225

23-
@g.step
24-
async def increment(ctx: StepContext[MyState, None, None]) -> int:
25-
ctx.state.counter += 1
26-
return ctx.state.counter
26+
g.add(
27+
g.edge_from(g.start_node).to(increment),
28+
g.edge_from(increment).to(g.end_node),
29+
)
2730

28-
g.add(
29-
g.edge_from(g.start_node).to(increment),
30-
g.edge_from(increment).to(g.end_node),
31-
)
31+
graph = g.build()
3232

33-
graph = g.build()
33+
async def main():
3434
state = MyState()
3535
result = await graph.run(state=state)
3636
print(result)
@@ -195,8 +195,11 @@ _(This example is complete, it can be run "as is" — you'll need to add `import
195195
By default, step node IDs are inferred from the function name. You can override this:
196196

197197
```python {title="custom_id.py" requires="basic_step.py"}
198+
from pydantic_graph.beta import StepContext
199+
198200
from basic_step import MyState, g
199201

202+
200203
@g.step(node_id='my_custom_id')
201204
async def my_step(ctx: StepContext[MyState, None, None]) -> int:
202205
return 42
@@ -209,8 +212,11 @@ async def my_step(ctx: StepContext[MyState, None, None]) -> int:
209212
Labels provide documentation for diagram generation:
210213

211214
```python {title="labels.py" requires="basic_step.py"}
215+
from pydantic_graph.beta import StepContext
216+
212217
from basic_step import MyState, g
213218

219+
214220
@g.step(label='Increment the counter')
215221
async def increment(ctx: StepContext[MyState, None, None]) -> int:
216222
ctx.state.counter += 1
@@ -330,6 +336,17 @@ The beta graph API provides strong type checking through generics. Type paramete
330336
- Input/output types match across edges
331337

332338
```python
339+
from dataclasses import dataclass
340+
341+
from pydantic_graph.beta import GraphBuilder, StepContext
342+
343+
344+
@dataclass
345+
class MyState:
346+
pass
347+
348+
g = GraphBuilder(state_type=MyState, output_type=str)
349+
333350
# Type checker will catch mismatches
334351
@g.step
335352
async def expects_int(ctx: StepContext[MyState, None, int]) -> str:

0 commit comments

Comments
 (0)