Skip to content

Commit 37636f0

Browse files
chore: fix ruff linting and mypy issues in flow module
1 parent 0e37059 commit 37636f0

File tree

12 files changed

+233
-197
lines changed

12 files changed

+233
-197
lines changed

src/crewai/flow/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from crewai.flow.flow import Flow, start, listen, or_, and_, router
1+
from crewai.flow.flow import Flow, and_, listen, or_, router, start
22
from crewai.flow.persistence import persist
33

4-
__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"]
5-
4+
__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]

src/crewai/flow/flow_trackable.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import inspect
2-
from typing import Optional
32

43
from pydantic import BaseModel, Field, InstanceOf, model_validator
54

@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
1413
inspecting the call stack.
1514
"""
1615

17-
parent_flow: Optional[InstanceOf[Flow]] = Field(
16+
parent_flow: InstanceOf[Flow] | None = Field(
1817
default=None,
1918
description="The parent flow of the instance, if it was created inside a flow.",
2019
)

src/crewai/flow/flow_visualizer.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
# flow_visualizer.py
22

33
import os
4-
from pathlib import Path
54

6-
from pyvis.network import Network
5+
from pyvis.network import Network # type: ignore[import-untyped]
76

87
from crewai.flow.config import COLORS, NODE_STYLES
98
from crewai.flow.html_template_handler import HTMLTemplateHandler
109
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
11-
from crewai.flow.path_utils import safe_path_join, validate_path_exists
10+
from crewai.flow.path_utils import safe_path_join
1211
from crewai.flow.utils import calculate_node_levels
1312
from crewai.flow.visualization_utils import (
1413
add_edges,
@@ -34,13 +33,13 @@ def __init__(self, flow):
3433
ValueError
3534
If flow object is invalid or missing required attributes.
3635
"""
37-
if not hasattr(flow, '_methods'):
36+
if not hasattr(flow, "_methods"):
3837
raise ValueError("Invalid flow object: missing '_methods' attribute")
39-
if not hasattr(flow, '_listeners'):
38+
if not hasattr(flow, "_listeners"):
4039
raise ValueError("Invalid flow object: missing '_listeners' attribute")
41-
if not hasattr(flow, '_start_methods'):
40+
if not hasattr(flow, "_start_methods"):
4241
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
43-
42+
4443
self.flow = flow
4544
self.colors = COLORS
4645
self.node_styles = NODE_STYLES
@@ -65,7 +64,7 @@ def plot(self, filename):
6564
"""
6665
if not filename or not isinstance(filename, str):
6766
raise ValueError("Filename must be a non-empty string")
68-
67+
6968
try:
7069
# Initialize network
7170
net = Network(
@@ -96,45 +95,51 @@ def plot(self, filename):
9695
try:
9796
node_levels = calculate_node_levels(self.flow)
9897
except Exception as e:
99-
raise ValueError(f"Failed to calculate node levels: {str(e)}")
98+
raise ValueError(f"Failed to calculate node levels: {e!s}") from e
10099

101100
# Compute positions
102101
try:
103102
node_positions = compute_positions(self.flow, node_levels)
104103
except Exception as e:
105-
raise ValueError(f"Failed to compute node positions: {str(e)}")
104+
raise ValueError(f"Failed to compute node positions: {e!s}") from e
106105

107106
# Add nodes to the network
108107
try:
109108
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
110109
except Exception as e:
111-
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
110+
raise RuntimeError(f"Failed to add nodes to network: {e!s}") from e
112111

113112
# Add edges to the network
114113
try:
115114
add_edges(net, self.flow, node_positions, self.colors)
116115
except Exception as e:
117-
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
116+
raise RuntimeError(f"Failed to add edges to network: {e!s}") from e
118117

119118
# Generate HTML
120119
try:
121120
network_html = net.generate_html()
122121
final_html_content = self._generate_final_html(network_html)
123122
except Exception as e:
124-
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
123+
raise RuntimeError(
124+
f"Failed to generate network visualization: {e!s}"
125+
) from e
125126

126127
# Save the final HTML content to the file
127128
try:
128129
with open(f"{filename}.html", "w", encoding="utf-8") as f:
129130
f.write(final_html_content)
130131
print(f"Plot saved as {filename}.html")
131132
except IOError as e:
132-
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
133+
raise IOError(
134+
f"Failed to save flow visualization to {filename}.html: {e!s}"
135+
) from e
133136

134137
except (ValueError, RuntimeError, IOError) as e:
135138
raise e
136139
except Exception as e:
137-
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
140+
raise RuntimeError(
141+
f"Unexpected error during flow visualization: {e!s}"
142+
) from e
138143
finally:
139144
self._cleanup_pyvis_lib()
140145

@@ -165,7 +170,9 @@ def _generate_final_html(self, network_html):
165170
try:
166171
# Extract just the body content from the generated HTML
167172
current_dir = os.path.dirname(__file__)
168-
template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir)
173+
template_path = safe_path_join(
174+
"assets", "crewai_flow_visual_template.html", root=current_dir
175+
)
169176
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
170177

171178
if not os.path.exists(template_path):
@@ -179,12 +186,9 @@ def _generate_final_html(self, network_html):
179186
# Generate the legend items HTML
180187
legend_items = get_legend_items(self.colors)
181188
legend_items_html = generate_legend_items_html(legend_items)
182-
final_html_content = html_handler.generate_final_html(
183-
network_body, legend_items_html
184-
)
185-
return final_html_content
189+
return html_handler.generate_final_html(network_body, legend_items_html)
186190
except Exception as e:
187-
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
191+
raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
188192

189193
def _cleanup_pyvis_lib(self):
190194
"""
@@ -197,6 +201,7 @@ def _cleanup_pyvis_lib(self):
197201
lib_folder = safe_path_join("lib", root=os.getcwd())
198202
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
199203
import shutil
204+
200205
shutil.rmtree(lib_folder)
201206
except ValueError as e:
202207
print(f"Error validating lib folder path: {e}")

src/crewai/flow/html_template_handler.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import base64
22
import re
3-
from pathlib import Path
43

5-
from crewai.flow.path_utils import safe_path_join, validate_path_exists
4+
from crewai.flow.path_utils import validate_path_exists
65

76

87
class HTMLTemplateHandler:
@@ -28,7 +27,7 @@ def __init__(self, template_path, logo_path):
2827
self.template_path = validate_path_exists(template_path, "file")
2928
self.logo_path = validate_path_exists(logo_path, "file")
3029
except ValueError as e:
31-
raise ValueError(f"Invalid template or logo path: {e}")
30+
raise ValueError(f"Invalid template or logo path: {e}") from e
3231

3332
def read_template(self):
3433
"""Read and return the HTML template file contents."""
@@ -53,23 +52,23 @@ def generate_legend_items_html(self, legend_items):
5352
if "border" in item:
5453
legend_items_html += f"""
5554
<div class="legend-item">
56-
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
57-
<div>{item['label']}</div>
55+
<div class="legend-color-box" style="background-color: {item["color"]}; border: 2px dashed {item["border"]};"></div>
56+
<div>{item["label"]}</div>
5857
</div>
5958
"""
6059
elif item.get("dashed") is not None:
6160
style = "dashed" if item["dashed"] else "solid"
6261
legend_items_html += f"""
6362
<div class="legend-item">
64-
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
65-
<div>{item['label']}</div>
63+
<div class="legend-{style}" style="border-bottom: 2px {style} {item["color"]};"></div>
64+
<div>{item["label"]}</div>
6665
</div>
6766
"""
6867
else:
6968
legend_items_html += f"""
7069
<div class="legend-item">
71-
<div class="legend-color-box" style="background-color: {item['color']};"></div>
72-
<div>{item['label']}</div>
70+
<div class="legend-color-box" style="background-color: {item["color"]};"></div>
71+
<div>{item["label"]}</div>
7372
</div>
7473
"""
7574
return legend_items_html
@@ -79,15 +78,9 @@ def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"
7978
html_template = self.read_template()
8079
logo_svg_base64 = self.encode_logo()
8180

82-
final_html_content = html_template.replace("{{ title }}", title)
83-
final_html_content = final_html_content.replace(
84-
"{{ network_content }}", network_body
81+
return (
82+
html_template.replace("{{ title }}", title)
83+
.replace("{{ network_content }}", network_body)
84+
.replace("{{ logo_svg_base64 }}", logo_svg_base64)
85+
.replace("<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html)
8586
)
86-
final_html_content = final_html_content.replace(
87-
"{{ logo_svg_base64 }}", logo_svg_base64
88-
)
89-
final_html_content = final_html_content.replace(
90-
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
91-
)
92-
93-
return final_html_content

src/crewai/flow/path_utils.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55
traversal attacks and ensure paths remain within allowed boundaries.
66
"""
77

8-
import os
98
from pathlib import Path
10-
from typing import List, Union
119

1210

13-
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
11+
def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
1412
"""
1513
Safely join path components and ensure the result is within allowed boundaries.
1614
@@ -43,25 +41,25 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
4341

4442
# Establish root directory
4543
root_path = Path(root).resolve() if root else Path.cwd()
46-
44+
4745
# Join and resolve the full path
4846
full_path = Path(root_path, *clean_parts).resolve()
49-
47+
5048
# Check if the resolved path is within root
5149
if not str(full_path).startswith(str(root_path)):
5250
raise ValueError(
5351
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
5452
)
55-
53+
5654
return str(full_path)
57-
55+
5856
except Exception as e:
5957
if isinstance(e, ValueError):
6058
raise
61-
raise ValueError(f"Invalid path components: {str(e)}")
59+
raise ValueError(f"Invalid path components: {e!s}") from e
6260

6361

64-
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
62+
def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
6563
"""
6664
Validate that a path exists and is of the expected type.
6765
@@ -84,24 +82,24 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
8482
"""
8583
try:
8684
path_obj = Path(path).resolve()
87-
85+
8886
if not path_obj.exists():
8987
raise ValueError(f"Path does not exist: {path}")
90-
88+
9189
if file_type == "file" and not path_obj.is_file():
9290
raise ValueError(f"Path is not a file: {path}")
93-
elif file_type == "directory" and not path_obj.is_dir():
91+
if file_type == "directory" and not path_obj.is_dir():
9492
raise ValueError(f"Path is not a directory: {path}")
95-
93+
9694
return str(path_obj)
97-
95+
9896
except Exception as e:
9997
if isinstance(e, ValueError):
10098
raise
101-
raise ValueError(f"Invalid path: {str(e)}")
99+
raise ValueError(f"Invalid path: {e!s}") from e
102100

103101

104-
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
102+
def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
105103
"""
106104
Safely list files in a directory matching a pattern.
107105
@@ -126,10 +124,10 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
126124
dir_path = Path(directory).resolve()
127125
if not dir_path.is_dir():
128126
raise ValueError(f"Not a directory: {directory}")
129-
127+
130128
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
131-
129+
132130
except Exception as e:
133131
if isinstance(e, ValueError):
134132
raise
135-
raise ValueError(f"Error listing files: {str(e)}")
133+
raise ValueError(f"Error listing files: {e!s}") from e

src/crewai/flow/persistence/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
This module provides interfaces and implementations for persisting flow states.
55
"""
66

7-
from typing import Any, Dict, TypeVar, Union
7+
from typing import Any, TypeVar
88

99
from pydantic import BaseModel
1010

1111
from crewai.flow.persistence.base import FlowPersistence
1212
from crewai.flow.persistence.decorators import persist
1313
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
1414

15-
__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"]
15+
__all__ = ["FlowPersistence", "SQLiteFlowPersistence", "persist"]
1616

17-
StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel])
18-
DictStateType = Dict[str, Any]
17+
StateType = TypeVar("StateType", bound=dict[str, Any] | BaseModel)
18+
DictStateType = dict[str, Any]

0 commit comments

Comments
 (0)