|
10 | 10 | from pydantic_graph.beta.graph_builder import GraphBuildingError |
11 | 11 | from pydantic_graph.beta.join import reduce_list_append, reduce_sum |
12 | 12 | from pydantic_graph.beta.node import Fork |
| 13 | +from pydantic_graph.exceptions import GraphValidationError |
13 | 14 |
|
14 | 15 | pytestmark = pytest.mark.anyio |
15 | 16 |
|
@@ -326,3 +327,118 @@ async def source(ctx: StepContext[None, None, int]) -> list[int]: |
326 | 327 | match='For every Join J in the graph, there must be a Fork F between the StartNode and J satisfying', |
327 | 328 | ): |
328 | 329 | g.build() |
| 330 | + |
| 331 | + |
| 332 | +async def test_validation_no_edges_from_start(): |
| 333 | + """Test that validation catches graphs with no edges from start node.""" |
| 334 | + g = GraphBuilder(output_type=int) |
| 335 | + |
| 336 | + @g.step |
| 337 | + async def orphan_step(ctx: StepContext[None, None, None]) -> int: |
| 338 | + return 42 # pragma: no cover |
| 339 | + |
| 340 | + # Add the step to the graph but don't connect it to start |
| 341 | + g.add(g.edge_from(orphan_step).to(g.end_node)) |
| 342 | + |
| 343 | + with pytest.raises(GraphValidationError, match='The graph has no edges from the start node'): |
| 344 | + g.build() |
| 345 | + |
| 346 | + |
| 347 | +async def test_validation_no_edges_to_end(): |
| 348 | + """Test that validation catches graphs with no edges to end node.""" |
| 349 | + g = GraphBuilder(output_type=int) |
| 350 | + |
| 351 | + @g.step |
| 352 | + async def dead_end_step(ctx: StepContext[None, None, None]) -> int: |
| 353 | + return 42 # pragma: no cover |
| 354 | + |
| 355 | + # Connect start to step but don't connect step to end |
| 356 | + g.add(g.edge_from(g.start_node).to(dead_end_step)) |
| 357 | + |
| 358 | + with pytest.raises(GraphValidationError, match='The graph has no edges to the end node'): |
| 359 | + g.build() |
| 360 | + |
| 361 | + |
| 362 | +async def test_validation_node_with_no_outgoing_edges(): |
| 363 | + """Test that validation catches nodes with no outgoing edges.""" |
| 364 | + g = GraphBuilder(output_type=int) |
| 365 | + |
| 366 | + @g.step |
| 367 | + async def first_step(ctx: StepContext[None, None, None]) -> int: |
| 368 | + return 42 # pragma: no cover |
| 369 | + |
| 370 | + @g.step |
| 371 | + async def dead_end_step(ctx: StepContext[None, None, int]) -> int: |
| 372 | + return ctx.inputs # pragma: no cover |
| 373 | + |
| 374 | + # first_step connects to both dead_end_step and end_node |
| 375 | + # But dead_end_step has no outgoing edges |
| 376 | + g.add( |
| 377 | + g.edge_from(g.start_node).to(first_step), |
| 378 | + g.edge_from(first_step).to(dead_end_step, g.end_node), |
| 379 | + ) |
| 380 | + |
| 381 | + with pytest.raises(GraphValidationError, match='The following nodes have no outgoing edges'): |
| 382 | + g.build() |
| 383 | + |
| 384 | + |
| 385 | +async def test_validation_end_node_unreachable(): |
| 386 | + """Test that validation catches when end node is unreachable from start.""" |
| 387 | + g = GraphBuilder(input_type=int, output_type=int) |
| 388 | + |
| 389 | + @g.step |
| 390 | + async def first_step(ctx: StepContext[None, None, int]) -> int: |
| 391 | + return 42 # pragma: no cover |
| 392 | + |
| 393 | + @g.step |
| 394 | + async def second_step(ctx: StepContext[None, None, int]) -> int: |
| 395 | + return ctx.inputs # pragma: no cover |
| 396 | + |
| 397 | + # Create a cycle that doesn't reach the end node |
| 398 | + g.add( |
| 399 | + g.edge_from(g.start_node).to(first_step), |
| 400 | + g.edge_from(first_step).to(second_step), |
| 401 | + g.edge_from(second_step).to(first_step), |
| 402 | + ) |
| 403 | + |
| 404 | + with pytest.raises(GraphValidationError, match='The graph has no edges to the end node'): |
| 405 | + g.build() |
| 406 | + |
| 407 | + |
| 408 | +async def test_validation_unreachable_nodes(): |
| 409 | + """Test that validation catches nodes that are not reachable from start.""" |
| 410 | + g = GraphBuilder(output_type=int) |
| 411 | + |
| 412 | + @g.step |
| 413 | + async def reachable_step(ctx: StepContext[None, None, None]) -> int: |
| 414 | + return 10 |
| 415 | + |
| 416 | + @g.step |
| 417 | + async def unreachable_step(ctx: StepContext[None, None, int]) -> int: |
| 418 | + return ctx.inputs * 2 # pragma: no cover |
| 419 | + |
| 420 | + # unreachable_step is in the graph but not connected from start |
| 421 | + g.add( |
| 422 | + g.edge_from(g.start_node).to(reachable_step), |
| 423 | + g.edge_from(reachable_step).to(g.end_node), |
| 424 | + g.edge_from(unreachable_step).to(g.end_node), |
| 425 | + ) |
| 426 | + |
| 427 | + with pytest.raises(GraphValidationError, match='The following nodes are not reachable from the start node'): |
| 428 | + g.build() |
| 429 | + |
| 430 | + |
| 431 | +async def test_validation_can_be_disabled(): |
| 432 | + """Test that validation can be disabled with validate_graph_structure=False.""" |
| 433 | + g = GraphBuilder(output_type=int) |
| 434 | + |
| 435 | + @g.step |
| 436 | + async def orphan_step(ctx: StepContext[None, None, None]) -> int: |
| 437 | + return 42 # pragma: no cover |
| 438 | + |
| 439 | + # Add the step to the graph but don't connect it to start |
| 440 | + # This would normally fail validation |
| 441 | + g.add(g.edge_from(orphan_step).to(g.end_node)) |
| 442 | + |
| 443 | + # Should not raise an error when validation is disabled |
| 444 | + g.build(validate_graph_structure=False) |
0 commit comments