Skip to content

Commit 59cb630

Browse files
feat(tests): add tests for decorator and int -> float
1 parent 920fc0e commit 59cb630

File tree

1 file changed

+55
-1
lines changed

1 file changed

+55
-1
lines changed

tests/nodes/test_node_graph.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
12
from .test_nodes import (
23
ImageToImageTestInvocation,
34
TextToImageTestInvocation,
@@ -20,7 +21,7 @@
2021

2122
from invokeai.app.invocations.image import ShowImageInvocation
2223
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
23-
from invokeai.app.invocations.primitives import IntegerInvocation
24+
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
2425
from invokeai.app.services.default_graphs import create_text_to_image
2526
import pytest
2627

@@ -610,6 +611,59 @@ def test_graph_can_deserialize():
610611
assert g2.edges[0].destination.field == "image"
611612

612613

614+
def test_invocation_decorator():
615+
invocation_type = "test_invocation"
616+
title = "Test Invocation"
617+
tags = ["first", "second", "third"]
618+
category = "category"
619+
620+
@invocation(invocation_type, title=title, tags=tags, category=category)
621+
class Test(BaseInvocation):
622+
def invoke(self):
623+
pass
624+
625+
schema = Test.schema()
626+
627+
assert schema.get("title") == title
628+
assert schema.get("tags") == tags
629+
assert schema.get("category") == category
630+
assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added)
631+
632+
633+
def test_invocation_output_decorator():
634+
output_type = "test_output"
635+
636+
@invocation_output(output_type)
637+
class TestOutput(BaseInvocationOutput):
638+
pass
639+
640+
assert TestOutput().type == output_type # type: ignore (type is dynamically added)
641+
642+
643+
def test_floats_accept_ints():
644+
g = Graph()
645+
n1 = IntegerInvocation(id="1", value=1)
646+
n2 = FloatInvocation(id="2")
647+
g.add_node(n1)
648+
g.add_node(n2)
649+
e = create_edge(n1.id, "value", n2.id, "value")
650+
651+
# Not throwing on this line is sufficient
652+
g.add_edge(e)
653+
654+
655+
def test_ints_do_not_accept_floats():
656+
g = Graph()
657+
n1 = FloatInvocation(id="1", value=1.0)
658+
n2 = IntegerInvocation(id="2")
659+
g.add_node(n1)
660+
g.add_node(n2)
661+
e = create_edge(n1.id, "value", n2.id, "value")
662+
663+
with pytest.raises(InvalidEdgeError):
664+
g.add_edge(e)
665+
666+
613667
def test_graph_can_generate_schema():
614668
# Not throwing on this line is sufficient
615669
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation

0 commit comments

Comments
 (0)