Skip to content

Commit 3dbb0e1

Browse files
feat(tests): add tests for node versions
1 parent d6317bc commit 3dbb0e1

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
from ..services.invocation_services import InvocationServices
3333

3434

35+
class InvalidVersionError(ValueError):
36+
pass
37+
38+
3539
class FieldDescriptions:
3640
denoising_start = "When to start denoising, expressed a percentage of total steps"
3741
denoising_end = "When to stop denoising, expressed a percentage of total steps"
@@ -605,7 +609,10 @@ def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
605609
if category is not None:
606610
cls.UIConfig.category = category
607611
if version is not None:
608-
semver.Version.parse(version) # raises ValueError if invalid semver
612+
try:
613+
semver.Version.parse(version)
614+
except ValueError as e:
615+
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
609616
cls.UIConfig.version = version
610617

611618
# Add the invocation type to the pydantic model of the invocation

tests/nodes/test_node_graph.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
1+
from invokeai.app.invocations.baseinvocation import (
2+
BaseInvocation,
3+
BaseInvocationOutput,
4+
InvalidVersionError,
5+
invocation,
6+
invocation_output,
7+
)
28
from .test_nodes import (
39
ImageToImageTestInvocation,
410
TextToImageTestInvocation,
@@ -616,18 +622,38 @@ def test_invocation_decorator():
616622
title = "Test Invocation"
617623
tags = ["first", "second", "third"]
618624
category = "category"
625+
version = "1.2.3"
619626

620-
@invocation(invocation_type, title=title, tags=tags, category=category)
621-
class Test(BaseInvocation):
627+
@invocation(invocation_type, title=title, tags=tags, category=category, version=version)
628+
class TestInvocation(BaseInvocation):
622629
def invoke(self):
623630
pass
624631

625-
schema = Test.schema()
632+
schema = TestInvocation.schema()
626633

627634
assert schema.get("title") == title
628635
assert schema.get("tags") == tags
629636
assert schema.get("category") == category
630-
assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added)
637+
assert schema.get("version") == version
638+
assert TestInvocation(id="1").type == invocation_type # type: ignore (type is dynamically added)
639+
640+
641+
def test_invocation_version_must_be_semver():
642+
invocation_type = "test_invocation"
643+
valid_version = "1.0.0"
644+
invalid_version = "not_semver"
645+
646+
@invocation(invocation_type, version=valid_version)
647+
class ValidVersionInvocation(BaseInvocation):
648+
def invoke(self):
649+
pass
650+
651+
with pytest.raises(InvalidVersionError):
652+
653+
@invocation(invocation_type, version=invalid_version)
654+
class InvalidVersionInvocation(BaseInvocation):
655+
def invoke(self):
656+
pass
631657

632658

633659
def test_invocation_output_decorator():

0 commit comments

Comments
 (0)