Skip to content

Commit 91d6723

Browse files
feat: Add unit tests for models
Adds a comprehensive suite of unit tests for the Pydantic models in `src/python_workflow_definition/models.py`, achieving 100% test coverage. The new test file, `tests/test_models.py`, covers: - All Pydantic models and their validation rules. - Serialization and deserialization methods. - Various error-handling scenarios. A small refactoring was also done in `models.py` to remove an unreachable `except json.JSONDecodeError` block and to add a `model_validator` to ensure consistent behavior of default values.
1 parent a2809ea commit 91d6723

File tree

2 files changed

+186
-4
lines changed

2 files changed

+186
-4
lines changed

src/python_workflow_definition/models.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pathlib import Path
22
from typing import List, Union, Optional, Literal, Any, Annotated, Type, TypeVar
3-
from pydantic import BaseModel, Field, field_validator, field_serializer
3+
from pydantic import BaseModel, Field, field_validator, field_serializer, model_validator
44
from pydantic import ValidationError
55
import json
66
import logging
@@ -83,6 +83,13 @@ class PythonWorkflowDefinitionEdge(BaseModel):
8383
source: int
8484
sourcePort: Optional[str] = None
8585

86+
@model_validator(mode='before')
87+
@classmethod
88+
def set_default_source_port(cls, data: Any) -> Any:
89+
if isinstance(data, dict) and 'sourcePort' not in data:
90+
data['sourcePort'] = None
91+
return data
92+
8693
@field_validator("sourcePort", mode="before")
8794
@classmethod
8895
def handle_default_source(cls, v: Any) -> Optional[str]:
@@ -215,9 +222,6 @@ def load_json_str(cls: Type[T], json_data: Union[str, bytes]) -> dict:
215222
except ValidationError: # Catch validation errors specifically
216223
logger.error("Workflow model validation failed.", exc_info=True)
217224
raise
218-
except json.JSONDecodeError: # Catch JSON parsing errors specifically
219-
logger.error("Invalid JSON format encountered.", exc_info=True)
220-
raise
221225
except Exception as e: # Catch any other unexpected errors
222226
logger.error(
223227
f"An unexpected error occurred during JSON loading: {e}", exc_info=True

tests/test_models.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import unittest
2+
import json
3+
from pathlib import Path
4+
from unittest import mock
5+
from pydantic import ValidationError
6+
from python_workflow_definition.models import (
7+
PythonWorkflowDefinitionInputNode,
8+
PythonWorkflowDefinitionOutputNode,
9+
PythonWorkflowDefinitionFunctionNode,
10+
PythonWorkflowDefinitionEdge,
11+
PythonWorkflowDefinitionWorkflow,
12+
INTERNAL_DEFAULT_HANDLE,
13+
)
14+
15+
class TestModels(unittest.TestCase):
16+
def setUp(self):
17+
self.valid_workflow_dict = {
18+
"version": "1.0",
19+
"nodes": [
20+
{"id": 1, "type": "input", "name": "a", "value": 1},
21+
{"id": 2, "type": "function", "value": "math.add"},
22+
{"id": 3, "type": "output", "name": "result"},
23+
],
24+
"edges": [
25+
{"source": 1, "target": 2, "targetPort": "x"},
26+
{"source": 2, "target": 3, "sourcePort": None},
27+
],
28+
}
29+
self.workflow = PythonWorkflowDefinitionWorkflow(**self.valid_workflow_dict)
30+
31+
def test_input_node(self):
32+
node = PythonWorkflowDefinitionInputNode(id=1, type="input", name="test_input")
33+
self.assertEqual(node.id, 1)
34+
self.assertEqual(node.type, "input")
35+
self.assertEqual(node.name, "test_input")
36+
self.assertIsNone(node.value)
37+
38+
node_with_value = PythonWorkflowDefinitionInputNode(
39+
id=2, type="input", name="test_input_2", value=42
40+
)
41+
self.assertEqual(node_with_value.value, 42)
42+
43+
def test_output_node(self):
44+
node = PythonWorkflowDefinitionOutputNode(id=1, type="output", name="test_output")
45+
self.assertEqual(node.id, 1)
46+
self.assertEqual(node.type, "output")
47+
self.assertEqual(node.name, "test_output")
48+
49+
def test_function_node(self):
50+
node = PythonWorkflowDefinitionFunctionNode(
51+
id=1, type="function", value="module.function"
52+
)
53+
self.assertEqual(node.id, 1)
54+
self.assertEqual(node.type, "function")
55+
self.assertEqual(node.value, "module.function")
56+
57+
def test_function_node_invalid_value(self):
58+
with self.assertRaises(ValidationError):
59+
PythonWorkflowDefinitionFunctionNode(id=1, type="function", value="")
60+
with self.assertRaises(ValidationError):
61+
PythonWorkflowDefinitionFunctionNode(id=1, type="function", value="module")
62+
with self.assertRaises(ValidationError):
63+
PythonWorkflowDefinitionFunctionNode(id=1, type="function", value=".function")
64+
with self.assertRaises(ValidationError):
65+
PythonWorkflowDefinitionFunctionNode(id=1, type="function", value="module.")
66+
67+
def test_edge(self):
68+
edge = PythonWorkflowDefinitionEdge(source=1, target=2)
69+
self.assertEqual(edge.source, 1)
70+
self.assertEqual(edge.target, 2)
71+
self.assertEqual(edge.sourcePort, INTERNAL_DEFAULT_HANDLE)
72+
self.assertIsNone(edge.targetPort)
73+
74+
edge_with_ports = PythonWorkflowDefinitionEdge(
75+
source=1, sourcePort="out", target=2, targetPort="in"
76+
)
77+
self.assertEqual(edge_with_ports.sourcePort, "out")
78+
self.assertEqual(edge_with_ports.targetPort, "in")
79+
80+
def test_edge_default_source_handle(self):
81+
edge = PythonWorkflowDefinitionEdge(source=1, target=2, sourcePort=None)
82+
self.assertEqual(edge.sourcePort, INTERNAL_DEFAULT_HANDLE)
83+
84+
def test_edge_explicit_default_source_handle(self):
85+
with self.assertRaises(ValidationError):
86+
PythonWorkflowDefinitionEdge(
87+
source=1, target=2, sourcePort=INTERNAL_DEFAULT_HANDLE
88+
)
89+
90+
def test_edge_serialization(self):
91+
edge = PythonWorkflowDefinitionEdge(source=1, target=2, sourcePort=None)
92+
self.assertIsNone(edge.model_dump(mode="json")["sourcePort"])
93+
94+
edge_with_port = PythonWorkflowDefinitionEdge(
95+
source=1, target=2, sourcePort="out"
96+
)
97+
self.assertEqual(edge_with_port.model_dump(mode="json")["sourcePort"], "out")
98+
99+
def test_workflow_model(self):
100+
self.assertEqual(len(self.workflow.nodes), 3)
101+
self.assertEqual(len(self.workflow.edges), 2)
102+
self.assertIsInstance(
103+
self.workflow.nodes[0], PythonWorkflowDefinitionInputNode
104+
)
105+
106+
def test_dump_json(self):
107+
json_str = self.workflow.dump_json()
108+
data = json.loads(json_str)
109+
self.assertEqual(data["version"], self.valid_workflow_dict["version"])
110+
self.assertEqual(len(data["nodes"]), 3)
111+
self.assertEqual(len(data["edges"]), 2)
112+
self.assertIsNone(data["edges"][1]["sourcePort"])
113+
114+
def test_dump_json_file(self):
115+
file_path = Path("test_workflow.json")
116+
if file_path.exists():
117+
file_path.unlink()
118+
self.workflow.dump_json_file(file_path)
119+
self.assertTrue(file_path.exists())
120+
with open(file_path, "r") as f:
121+
data = json.load(f)
122+
self.assertEqual(data["version"], self.valid_workflow_dict["version"])
123+
file_path.unlink()
124+
125+
def test_load_json_str(self):
126+
json_str = self.workflow.dump_json()
127+
loaded_workflow_dict = PythonWorkflowDefinitionWorkflow.load_json_str(json_str)
128+
reloaded_workflow = PythonWorkflowDefinitionWorkflow(**loaded_workflow_dict)
129+
self.assertEqual(reloaded_workflow.edges[1].sourcePort, INTERNAL_DEFAULT_HANDLE)
130+
131+
def test_load_json_str_invalid(self):
132+
with self.assertRaises(ValidationError):
133+
PythonWorkflowDefinitionWorkflow.load_json_str('{"version": "1.0", "nodes": [], "edges": "not_a_list"}')
134+
with self.assertRaises(ValidationError):
135+
PythonWorkflowDefinitionWorkflow.load_json_str('{"version": "1.0", "nodes": []')
136+
with self.assertRaises(ValidationError):
137+
PythonWorkflowDefinitionWorkflow.load_json_str(123)
138+
139+
def test_load_json_file(self):
140+
file_path = Path("test_workflow.json")
141+
self.workflow.dump_json_file(file_path)
142+
loaded_workflow_dict = PythonWorkflowDefinitionWorkflow.load_json_file(file_path)
143+
reloaded_workflow = PythonWorkflowDefinitionWorkflow(**loaded_workflow_dict)
144+
self.assertEqual(reloaded_workflow.edges[1].sourcePort, INTERNAL_DEFAULT_HANDLE)
145+
file_path.unlink()
146+
147+
def test_load_json_file_not_found(self):
148+
with self.assertRaises(FileNotFoundError):
149+
PythonWorkflowDefinitionWorkflow.load_json_file("non_existent_file.json")
150+
151+
def test_load_json_file_invalid_json(self):
152+
file_path = Path("invalid_workflow.json")
153+
with open(file_path, "w") as f:
154+
f.write('{"version": "1.0", "nodes": "invalid"}')
155+
with self.assertRaises(ValidationError):
156+
PythonWorkflowDefinitionWorkflow.load_json_file(file_path)
157+
file_path.unlink()
158+
159+
def test_dump_json_file_io_error(self):
160+
with self.assertRaises(IOError):
161+
self.workflow.dump_json_file("/")
162+
163+
@mock.patch("json.dumps")
164+
def test_dump_json_type_error(self, mock_dumps):
165+
mock_dumps.side_effect = TypeError("test error")
166+
with self.assertRaises(TypeError):
167+
self.workflow.dump_json()
168+
169+
@mock.patch("python_workflow_definition.models.PythonWorkflowDefinitionWorkflow.model_validate_json")
170+
def test_load_json_str_generic_exception(self, mock_validate):
171+
mock_validate.side_effect = Exception("generic error")
172+
with self.assertRaises(Exception) as cm:
173+
PythonWorkflowDefinitionWorkflow.load_json_str('{}')
174+
self.assertEqual(str(cm.exception), "generic error")
175+
176+
def test_load_json_file_io_error(self):
177+
with self.assertRaises(IOError):
178+
PythonWorkflowDefinitionWorkflow.load_json_file("/")

0 commit comments

Comments
 (0)