1
+ # -*- coding: utf-8 -*-
1
2
"""
2
3
Validate async code patterns and detect common pitfalls.
3
4
"""
10
11
11
12
class AsyncCodeValidator :
12
13
"""Validate async code for common patterns and pitfalls."""
13
-
14
+
14
15
def __init__ (self ):
15
16
self .issues = []
16
17
self .suggestions = []
17
-
18
+
18
19
def validate_directory (self , source_dir : Path ) -> Dict [str , Any ]:
19
20
"""Validate all Python files in directory."""
20
-
21
+
21
22
validation_results = {
22
23
'files_checked' : 0 ,
23
24
'issues_found' : 0 ,
24
25
'suggestions' : 0 ,
25
26
'details' : []
26
27
}
27
-
28
+
28
29
python_files = list (source_dir .rglob ("*.py" ))
29
-
30
+
30
31
for file_path in python_files :
31
32
if self ._should_skip_file (file_path ):
32
33
continue
33
-
34
+
34
35
file_results = self ._validate_file (file_path )
35
36
validation_results ['details' ].append (file_results )
36
37
validation_results ['files_checked' ] += 1
37
38
validation_results ['issues_found' ] += len (file_results ['issues' ])
38
39
validation_results ['suggestions' ] += len (file_results ['suggestions' ])
39
-
40
+
40
41
return validation_results
41
-
42
+
42
43
def _validate_file (self , file_path : Path ) -> Dict [str , Any ]:
43
44
"""Validate a single Python file."""
44
-
45
+
45
46
file_results = {
46
47
'file' : str (file_path ),
47
48
'issues' : [],
48
49
'suggestions' : []
49
50
}
50
-
51
+
51
52
try :
52
53
with open (file_path , 'r' , encoding = 'utf-8' ) as f :
53
54
source_code = f .read ()
54
-
55
+
55
56
tree = ast .parse (source_code , filename = str (file_path ))
56
-
57
+
57
58
# Analyze AST for async patterns
58
59
validator = AsyncPatternVisitor (file_path )
59
60
validator .visit (tree )
60
-
61
+
61
62
file_results ['issues' ] = validator .issues
62
63
file_results ['suggestions' ] = validator .suggestions
63
-
64
+
64
65
except Exception as e :
65
66
file_results ['issues' ].append ({
66
67
'type' : 'parse_error' ,
67
68
'message' : f"Failed to parse file: { str (e )} " ,
68
69
'line' : 0
69
70
})
70
-
71
+
71
72
return file_results
72
-
73
-
73
+
74
+
74
75
def _should_skip_file (self , file_path : Path ) -> bool :
75
76
"""Determine if a file should be skipped (e.g., __init__.py files)."""
76
77
return file_path .name == "__init__.py"
77
-
78
+
78
79
class AsyncPatternVisitor (ast .NodeVisitor ):
79
80
"""AST visitor to detect async patterns and issues."""
80
-
81
+
81
82
def __init__ (self , file_path : Path ):
82
83
self .file_path = file_path
83
84
self .issues = []
84
85
self .suggestions = []
85
86
self .in_async_function = False
86
-
87
+
87
88
def visit_AsyncFunctionDef (self , node ):
88
89
"""Visit async function definitions."""
89
-
90
+
90
91
self .in_async_function = True
91
-
92
+
92
93
# Check for blocking operations in async functions
93
94
self ._check_blocking_operations (node )
94
-
95
+
95
96
# Check for proper error handling
96
97
self ._check_error_handling (node )
97
-
98
+
98
99
self .generic_visit (node )
99
100
self .in_async_function = False
100
-
101
+
101
102
def visit_Call (self , node ):
102
103
"""Visit function calls."""
103
-
104
+
104
105
if self .in_async_function :
105
106
# Check for potentially unawaited async calls
106
107
self ._check_unawaited_calls (node )
107
-
108
+
108
109
# Check for blocking I/O operations
109
110
self ._check_blocking_io (node )
110
-
111
+
111
112
self .generic_visit (node )
112
-
113
+
113
114
def _check_blocking_operations (self , node ):
114
115
"""Check for blocking operations in async functions."""
115
-
116
+
116
117
blocking_patterns = [
117
118
'time.sleep' ,
118
119
'requests.get' , 'requests.post' ,
119
120
'subprocess.run' , 'subprocess.call' ,
120
121
'open' # File I/O without async
121
122
]
122
-
123
+
123
124
for child in ast .walk (node ):
124
125
if isinstance (child , ast .Call ):
125
126
call_name = self ._get_call_name (child )
@@ -130,18 +131,18 @@ def _check_blocking_operations(self, node):
130
131
'line' : child .lineno ,
131
132
'suggestion' : f"Use async equivalent of { call_name } "
132
133
})
133
-
134
+
134
135
def _check_unawaited_calls (self , node ):
135
136
"""Check for potentially unawaited async calls."""
136
-
137
+
137
138
# Look for calls that might return coroutines
138
139
async_patterns = [
139
140
'aiohttp' , 'asyncio' , 'asyncpg' ,
140
141
'websockets' , 'motor' # Common async libraries
141
142
]
142
-
143
+
143
144
call_name = self ._get_call_name (node )
144
-
145
+
145
146
for pattern in async_patterns :
146
147
if pattern in call_name :
147
148
# Check if this call is awaited
@@ -153,10 +154,10 @@ def _check_unawaited_calls(self, node):
153
154
'line' : node .lineno
154
155
})
155
156
break
156
-
157
+
157
158
def _get_call_name (self , node ):
158
159
"""Extract the name of a function call."""
159
-
160
+
160
161
if isinstance (node .func , ast .Name ):
161
162
return node .func .id
162
163
elif isinstance (node .func , ast .Attribute ):
@@ -165,19 +166,19 @@ def _get_call_name(self, node):
165
166
else :
166
167
return node .func .attr
167
168
return "unknown"
168
-
169
+
169
170
170
171
if __name__ == "__main__" :
171
172
parser = argparse .ArgumentParser (description = "Validate async code patterns and detect common pitfalls." )
172
173
parser .add_argument ("--source" , type = Path , required = True , help = "Source directory to validate." )
173
174
parser .add_argument ("--report" , type = Path , required = True , help = "Path to the output validation report." )
174
-
175
+
175
176
args = parser .parse_args ()
176
-
177
+
177
178
validator = AsyncCodeValidator ()
178
179
results = validator .validate_directory (args .source )
179
-
180
+
180
181
with open (args .report , 'w' ) as f :
181
182
json .dump (results , f , indent = 4 )
182
-
183
- print (f"Validation report saved to { args .report } " )
183
+
184
+ print (f"Validation report saved to { args .report } " )
0 commit comments