|
| 1 | +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output |
1 | 2 | from .test_nodes import ( |
2 | 3 | ImageToImageTestInvocation, |
3 | 4 | TextToImageTestInvocation, |
|
20 | 21 |
|
21 | 22 | from invokeai.app.invocations.image import ShowImageInvocation |
22 | 23 | from invokeai.app.invocations.math import AddInvocation, SubtractInvocation |
23 | | -from invokeai.app.invocations.primitives import IntegerInvocation |
| 24 | +from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation |
24 | 25 | from invokeai.app.services.default_graphs import create_text_to_image |
25 | 26 | import pytest |
26 | 27 |
|
@@ -610,6 +611,59 @@ def test_graph_can_deserialize(): |
610 | 611 | assert g2.edges[0].destination.field == "image" |
611 | 612 |
|
612 | 613 |
|
| 614 | +def test_invocation_decorator(): |
| 615 | + invocation_type = "test_invocation" |
| 616 | + title = "Test Invocation" |
| 617 | + tags = ["first", "second", "third"] |
| 618 | + category = "category" |
| 619 | + |
| 620 | + @invocation(invocation_type, title=title, tags=tags, category=category) |
| 621 | + class Test(BaseInvocation): |
| 622 | + def invoke(self): |
| 623 | + pass |
| 624 | + |
| 625 | + schema = Test.schema() |
| 626 | + |
| 627 | + assert schema.get("title") == title |
| 628 | + assert schema.get("tags") == tags |
| 629 | + assert schema.get("category") == category |
| 630 | + assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added) |
| 631 | + |
| 632 | + |
| 633 | +def test_invocation_output_decorator(): |
| 634 | + output_type = "test_output" |
| 635 | + |
| 636 | + @invocation_output(output_type) |
| 637 | + class TestOutput(BaseInvocationOutput): |
| 638 | + pass |
| 639 | + |
| 640 | + assert TestOutput().type == output_type # type: ignore (type is dynamically added) |
| 641 | + |
| 642 | + |
| 643 | +def test_floats_accept_ints(): |
| 644 | + g = Graph() |
| 645 | + n1 = IntegerInvocation(id="1", value=1) |
| 646 | + n2 = FloatInvocation(id="2") |
| 647 | + g.add_node(n1) |
| 648 | + g.add_node(n2) |
| 649 | + e = create_edge(n1.id, "value", n2.id, "value") |
| 650 | + |
| 651 | + # Not throwing on this line is sufficient |
| 652 | + g.add_edge(e) |
| 653 | + |
| 654 | + |
| 655 | +def test_ints_do_not_accept_floats(): |
| 656 | + g = Graph() |
| 657 | + n1 = FloatInvocation(id="1", value=1.0) |
| 658 | + n2 = IntegerInvocation(id="2") |
| 659 | + g.add_node(n1) |
| 660 | + g.add_node(n2) |
| 661 | + e = create_edge(n1.id, "value", n2.id, "value") |
| 662 | + |
| 663 | + with pytest.raises(InvalidEdgeError): |
| 664 | + g.add_edge(e) |
| 665 | + |
| 666 | + |
613 | 667 | def test_graph_can_generate_schema(): |
614 | 668 | # Not throwing on this line is sufficient |
615 | 669 | # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation |
|
0 commit comments