Skip to content

Commit 4e402b1

Browse files
authored
Reland union type (#5900)
* Reapply "Add union link connection type support (#5806)" (#5889) This reverts commit bf9a90a. * Fix union type breaks existing type workarounds * Add non-string test * Add tests for hacks and non-string types * Support python versions lower than 3.11
1 parent 4827244 commit 4e402b1

File tree

3 files changed

+161
-3
lines changed

3 files changed

+161
-3
lines changed

comfy_execution/validation.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from __future__ import annotations
2+
3+
4+
def validate_node_input(
5+
received_type: str, input_type: str, strict: bool = False
6+
) -> bool:
7+
"""
8+
received_type and input_type are both strings of the form "T1,T2,...".
9+
10+
If strict is True, the input_type must contain the received_type.
11+
For example, if received_type is "STRING" and input_type is "STRING,INT",
12+
this will return True. But if received_type is "STRING,INT" and input_type is
13+
"INT", this will return False.
14+
15+
If strict is False, the input_type must have overlap with the received_type.
16+
For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
17+
this will return True.
18+
19+
Supports pre-union type extension behaviour of ``__ne__`` overrides.
20+
"""
21+
# If the types are exactly the same, we can return immediately
22+
# Use pre-union behaviour: inverse of `__ne__`
23+
if not received_type != input_type:
24+
return True
25+
26+
# Not equal, and not strings
27+
if not isinstance(received_type, str) or not isinstance(input_type, str):
28+
return False
29+
30+
# Split the type strings into sets for comparison
31+
received_types = set(t.strip() for t in received_type.split(","))
32+
input_types = set(t.strip() for t in input_type.split(","))
33+
34+
if strict:
35+
# In strict mode, all received types must be in the input types
36+
return received_types.issubset(input_types)
37+
else:
38+
# In non-strict mode, there must be at least one type in common
39+
return len(received_types.intersection(input_types)) > 0

execution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
1717
from comfy_execution.graph_utils import is_link, GraphBuilder
1818
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
19+
from comfy_execution.validation import validate_node_input
1920
from comfy.cli_args import args
2021

2122
class ExecutionResult(Enum):
@@ -527,7 +528,6 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
527528
comfy.model_management.unload_all_models()
528529

529530

530-
531531
def validate_inputs(prompt, item, validated):
532532
unique_id = item
533533
if unique_id in validated:
@@ -589,8 +589,8 @@ def validate_inputs(prompt, item, validated):
589589
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
590590
received_type = r[val[1]]
591591
received_types[x] = received_type
592-
if 'input_types' not in validate_function_inputs and received_type != type_input:
593-
details = f"{x}, {received_type} != {type_input}"
592+
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
593+
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
594594
error = {
595595
"type": "return_type_mismatch",
596596
"message": "Return type mismatch between linked nodes",
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import pytest
2+
from comfy_execution.validation import validate_node_input
3+
4+
5+
def test_exact_match():
6+
"""Test cases where types match exactly"""
7+
assert validate_node_input("STRING", "STRING")
8+
assert validate_node_input("STRING,INT", "STRING,INT")
9+
assert validate_node_input("INT,STRING", "STRING,INT") # Order shouldn't matter
10+
11+
12+
def test_strict_mode():
13+
"""Test strict mode validation"""
14+
# Should pass - received type is subset of input type
15+
assert validate_node_input("STRING", "STRING,INT", strict=True)
16+
assert validate_node_input("INT", "STRING,INT", strict=True)
17+
assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True)
18+
19+
# Should fail - received type is not subset of input type
20+
assert not validate_node_input("STRING,INT", "STRING", strict=True)
21+
assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True)
22+
assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True)
23+
24+
25+
def test_non_strict_mode():
26+
"""Test non-strict mode validation (default behavior)"""
27+
# Should pass - types have overlap
28+
assert validate_node_input("STRING,BOOLEAN", "STRING,INT")
29+
assert validate_node_input("STRING,INT", "INT,BOOLEAN")
30+
assert validate_node_input("STRING", "STRING,INT")
31+
32+
# Should fail - no overlap in types
33+
assert not validate_node_input("BOOLEAN", "STRING,INT")
34+
assert not validate_node_input("FLOAT", "STRING,INT")
35+
assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT")
36+
37+
38+
def test_whitespace_handling():
39+
"""Test that whitespace is handled correctly"""
40+
assert validate_node_input("STRING, INT", "STRING,INT")
41+
assert validate_node_input("STRING,INT", "STRING, INT")
42+
assert validate_node_input(" STRING , INT ", "STRING,INT")
43+
assert validate_node_input("STRING,INT", " STRING , INT ")
44+
45+
46+
def test_empty_strings():
47+
"""Test behavior with empty strings"""
48+
assert validate_node_input("", "")
49+
assert not validate_node_input("STRING", "")
50+
assert not validate_node_input("", "STRING")
51+
52+
53+
def test_single_vs_multiple():
54+
"""Test single type against multiple types"""
55+
assert validate_node_input("STRING", "STRING,INT,BOOLEAN")
56+
assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False)
57+
assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True)
58+
59+
60+
def test_non_string():
61+
"""Test non-string types"""
62+
obj1 = object()
63+
obj2 = object()
64+
assert validate_node_input(obj1, obj1)
65+
assert not validate_node_input(obj1, obj2)
66+
67+
68+
class NotEqualsOverrideTest(str):
69+
"""Test class for ``__ne__`` override."""
70+
71+
def __ne__(self, value: object) -> bool:
72+
if self == "*" or value == "*":
73+
return False
74+
if self == "LONGER_THAN_2":
75+
return not len(value) > 2
76+
raise TypeError("This is a class for unit tests only.")
77+
78+
79+
def test_ne_override():
80+
"""Test ``__ne__`` any override"""
81+
any = NotEqualsOverrideTest("*")
82+
invalid_type = "INVALID_TYPE"
83+
obj = object()
84+
assert validate_node_input(any, any)
85+
assert validate_node_input(any, invalid_type)
86+
assert validate_node_input(any, obj)
87+
assert validate_node_input(any, {})
88+
assert validate_node_input(any, [])
89+
assert validate_node_input(any, [1, 2, 3])
90+
91+
92+
def test_ne_custom_override():
93+
"""Test ``__ne__`` custom override"""
94+
special = NotEqualsOverrideTest("LONGER_THAN_2")
95+
96+
assert validate_node_input(special, special)
97+
assert validate_node_input(special, "*")
98+
assert validate_node_input(special, "INVALID_TYPE")
99+
assert validate_node_input(special, [1, 2, 3])
100+
101+
# Should fail
102+
assert not validate_node_input(special, [1, 2])
103+
assert not validate_node_input(special, "TY")
104+
105+
106+
@pytest.mark.parametrize(
107+
"received,input_type,strict,expected",
108+
[
109+
("STRING", "STRING", False, True),
110+
("STRING,INT", "STRING,INT", False, True),
111+
("STRING", "STRING,INT", True, True),
112+
("STRING,INT", "STRING", True, False),
113+
("BOOLEAN", "STRING,INT", False, False),
114+
("STRING,BOOLEAN", "STRING,INT", False, True),
115+
],
116+
)
117+
def test_parametrized_cases(received, input_type, strict, expected):
118+
"""Parametrized test cases for various scenarios"""
119+
assert validate_node_input(received, input_type, strict) == expected

0 commit comments

Comments
 (0)