|
1 | 1 | from operator import attrgetter
|
| 2 | +from pathlib import Path |
2 | 3 | from copy import copy
|
3 | 4 | from unittest.mock import Mock
|
4 | 5 | import pytest
|
|
15 | 16 |
|
16 | 17 |
|
17 | 18 | @python.define
|
18 |
| -def Add(a, b): |
| 19 | +def Add(a: int | float, b: int | float) -> int | float: |
19 | 20 | return a + b
|
20 | 21 |
|
21 | 22 |
|
22 | 23 | @python.define
|
23 |
| -def Mul(a, b): |
| 24 | +def Mul(a: int | float, b: int | float) -> int | float: |
24 | 25 | return a * b
|
25 | 26 |
|
26 | 27 |
|
27 | 28 | @python.define(outputs=["divided"])
|
28 |
| -def Divide(x, y): |
| 29 | +def Divide(x: int | float, y: int | float) -> float: |
29 | 30 | return x / y
|
30 | 31 |
|
31 | 32 |
|
@@ -68,7 +69,9 @@ def MyTestWorkflow(a, b):
|
68 | 69 | wf = Workflow.construct(workflow_spec)
|
69 | 70 | assert wf.inputs.a == 1
|
70 | 71 | assert wf.inputs.b == 2.0
|
71 |
| - assert wf.outputs.out == LazyOutField(node=wf["Mul"], field="out", type=ty.Any) |
| 72 | + assert wf.outputs.out == LazyOutField( |
| 73 | + node=wf["Mul"], field="out", type=int | float, type_checked=True |
| 74 | + ) |
72 | 75 |
|
73 | 76 | # Nodes are named after the specs by default
|
74 | 77 | assert list(wf.node_names) == ["Add", "Mul"]
|
@@ -185,7 +188,9 @@ class Outputs(workflow.Outputs):
|
185 | 188 | wf = Workflow.construct(workflow_spec)
|
186 | 189 | assert wf.inputs.a == 1
|
187 | 190 | assert wf.inputs.b == 2.0
|
188 |
| - assert wf.outputs.out == LazyOutField(node=wf["Mul"], field="out", type=ty.Any) |
| 191 | + assert wf.outputs.out == LazyOutField( |
| 192 | + node=wf["Mul"], field="out", type=int | float, type_checked=True |
| 193 | + ) |
189 | 194 |
|
190 | 195 | # Nodes are named after the specs by default
|
191 | 196 | assert list(wf.node_names) == ["Add", "Mul"]
|
@@ -323,7 +328,7 @@ def MyTestWorkflow(a: int, b: float) -> tuple[float, float]:
|
323 | 328 | node=wf["Mul"], field="out", type=float, type_checked=True
|
324 | 329 | )
|
325 | 330 | assert wf.outputs.out2 == LazyOutField(
|
326 |
| - node=wf["division"], field="divided", type=ty.Any |
| 331 | + node=wf["division"], field="divided", type=float, type_checked=True |
327 | 332 | )
|
328 | 333 | assert list(wf.node_names) == ["addition", "Mul", "division"]
|
329 | 334 |
|
@@ -362,8 +367,12 @@ def MyTestWorkflow(a: int, b: float):
|
362 | 367 | wf = Workflow.construct(workflow_spec)
|
363 | 368 | assert wf.inputs.a == 1
|
364 | 369 | assert wf.inputs.b == 2.0
|
365 |
| - assert wf.outputs.out1 == LazyOutField(node=wf["Mul"], field="out", type=ty.Any) |
366 |
| - assert wf.outputs.out2 == LazyOutField(node=wf["Add"], field="out", type=ty.Any) |
| 370 | + assert wf.outputs.out1 == LazyOutField( |
| 371 | + node=wf["Mul"], field="out", type=int | float, type_checked=True |
| 372 | + ) |
| 373 | + assert wf.outputs.out2 == LazyOutField( |
| 374 | + node=wf["Add"], field="out", type=int | float, type_checked=True |
| 375 | + ) |
367 | 376 | assert list(wf.node_names) == ["Add", "Mul"]
|
368 | 377 |
|
369 | 378 |
|
@@ -500,3 +509,68 @@ def RecursiveNestedWorkflow(a: float, depth: int) -> float:
|
500 | 509 | type=float,
|
501 | 510 | type_checked=True,
|
502 | 511 | )
|
| 512 | + |
| 513 | + |
| 514 | +def test_workflow_lzout_inputs1(tmp_path: Path): |
| 515 | + |
| 516 | + @workflow.define |
| 517 | + def InputAccessWorkflow(a, b, c): |
| 518 | + add = workflow.add(Add(a=a, b=b)) |
| 519 | + add.inputs.a = c |
| 520 | + mul = workflow.add(Mul(a=add.out, b=b)) |
| 521 | + return mul.out |
| 522 | + |
| 523 | + input_access_workflow = InputAccessWorkflow(a=1, b=2.0, c=3.0) |
| 524 | + outputs = input_access_workflow(cache_root=tmp_path) |
| 525 | + assert outputs.out == 10.0 |
| 526 | + |
| 527 | + |
| 528 | +def test_workflow_lzout_inputs2(tmp_path: Path): |
| 529 | + |
| 530 | + @workflow.define |
| 531 | + def InputAccessWorkflow(a, b, c): |
| 532 | + add = workflow.add(Add(a=a, b=b)) |
| 533 | + add.inputs.a = c |
| 534 | + mul = workflow.add(Mul(a=add.out, b=b)) |
| 535 | + return mul.out |
| 536 | + |
| 537 | + input_access_workflow = InputAccessWorkflow(a=1, b=2.0, c=3.0) |
| 538 | + outputs = input_access_workflow(cache_root=tmp_path) |
| 539 | + assert outputs.out == 10.0 |
| 540 | + |
| 541 | + |
| 542 | +def test_workflow_lzout_inputs2(tmp_path: Path): |
| 543 | + """Set the inputs of the 'add' node after its outputs have been accessed |
| 544 | + but the state has not been altered""" |
| 545 | + |
| 546 | + @workflow.define |
| 547 | + def InputAccessWorkflow(a, b, c): |
| 548 | + add = workflow.add(Add(a=a, b=b)) |
| 549 | + mul = workflow.add(Mul(a=add.out, b=b)) |
| 550 | + add.inputs.a = c |
| 551 | + return mul.out |
| 552 | + |
| 553 | + input_access_workflow = InputAccessWorkflow(a=1, b=2.0, c=3.0) |
| 554 | + outputs = input_access_workflow(cache_root=tmp_path) |
| 555 | + assert outputs.out == 10.0 |
| 556 | + |
| 557 | + |
| 558 | +def test_workflow_lzout_inputs_state_change_fail(tmp_path: Path): |
| 559 | + """Set the inputs of the 'mul' node after its outputs have been accessed |
| 560 | + with an upstream lazy field that has a different state than the original. |
| 561 | + This changes the type of the input and is therefore not permitted""" |
| 562 | + |
| 563 | + @workflow.define |
| 564 | + def InputAccessWorkflow(a, b, c): |
| 565 | + add1 = workflow.add(Add(a=a, b=b), name="add1") |
| 566 | + add2 = workflow.add(Add(a=a).split(b=c), name="add2") |
| 567 | + mul1 = workflow.add(Mul(a=add1.out, b=b), name="mul1") |
| 568 | + mul2 = workflow.add(Mul(a=mul1.out, b=b), name="mul2") |
| 569 | + mul1.inputs.a = add2.out |
| 570 | + return mul2.out |
| 571 | + |
| 572 | + input_access_workflow = InputAccessWorkflow(a=1, b=2.0, c=[3.0, 4.0]) |
| 573 | + with pytest.raises( |
| 574 | + RuntimeError, match="have already been accessed and therefore cannot set" |
| 575 | + ): |
| 576 | + input_access_workflow.construct() |
0 commit comments