Skip to content

Commit 46da56e

Browse files
committed
Add pre-commit hooks with Ruff, Black, and mypy
1 parent f6c6860 commit 46da56e

File tree

3 files changed

+61
-20
lines changed

3 files changed

+61
-20
lines changed

.pre-commit-config.yaml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Pre-commit hooks for code quality
2+
repos:
3+
- repo: https://github.com/pre-commit/pre-commit-hooks
4+
rev: v4.5.0
5+
hooks:
6+
- id: trailing-whitespace
7+
- id: end-of-file-fixer
8+
- id: check-yaml
9+
- id: check-added-large-files
10+
- id: check-merge-conflict
11+
- id: check-toml
12+
13+
- repo: https://github.com/astral-sh/ruff-pre-commit
14+
rev: v0.1.6
15+
hooks:
16+
- id: ruff
17+
args: [--fix]
18+
- id: ruff-format
19+
20+
- repo: https://github.com/psf/black
21+
rev: 23.11.0
22+
hooks:
23+
- id: black
24+
25+
- repo: https://github.com/pre-commit/mirrors-mypy
26+
rev: v1.7.0
27+
hooks:
28+
- id: mypy
29+
additional_dependencies:
30+
- pydantic>=2.5.0
31+
- aiohttp>=3.9.0
32+
- click>=8.1.0
33+
- beautifulsoup4>=4.12.0
34+
args: [--ignore-missing-imports]
35+
exclude: ^tests/

src/sitescanner/core/scanner.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import logging
55
from datetime import datetime
6+
from typing import Any
67
from uuid import uuid4
78

89
import aiohttp
@@ -46,7 +47,7 @@ def __init__(self, config: ScanConfig) -> None:
4647
self._session: aiohttp.ClientSession | None = None
4748

4849
# Initialize scanner modules
49-
self.scanners = {
50+
self.scanners: dict[str, Any] = {
5051
"sql_injection": SQLInjectionScanner(),
5152
"xss": XSSScanner(),
5253
"csrf": CSRFScanner(),
@@ -61,7 +62,9 @@ async def __aenter__(self) -> "Scanner":
6162
)
6263
return self
6364

64-
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
65+
async def __aexit__(
66+
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object
67+
) -> None:
6568
"""Async context manager exit."""
6669
if self._session:
6770
await self._session.close()
@@ -96,7 +99,7 @@ async def scan(self) -> ScanResult:
9699
# Run enabled scanners concurrently
97100
scan_tasks = []
98101
for scanner_name in self.config.enabled_scanners:
99-
if scanner_name in self.scanners:
102+
if scanner_name in self.scanners and self._session:
100103
scanner = self.scanners[scanner_name]
101104
scan_tasks.append(scanner.scan_pages(pages, self._session))
102105

src/sitescanner/scanners/csrf.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import logging
5+
from typing import Any
56

67
import aiohttp
78
from bs4 import BeautifulSoup
@@ -44,14 +45,12 @@ def calculate_protection_level(self) -> str:
4445
]
4546
)
4647

47-
if protections == 0:
48-
return "none"
49-
elif protections == 1:
50-
return "weak"
51-
elif protections == 2:
52-
return "moderate"
53-
else:
54-
return "strong"
48+
protection_levels = {
49+
0: "none",
50+
1: "weak",
51+
2: "moderate",
52+
}
53+
return protection_levels.get(protections, "strong")
5554

5655

5756
class CSRFScanner:
@@ -134,7 +133,7 @@ async def _scan_page(self, url: str, session: aiohttp.ClientSession) -> list[Vul
134133

135134
return vulnerabilities
136135

137-
def _analyze_form(self, form: BeautifulSoup, page_url: str) -> CSRFTestCase:
136+
def _analyze_form(self, form: Any, page_url: str) -> CSRFTestCase:
138137
"""Analyze a form for CSRF protection using Pydantic validation.
139138
140139
Args:
@@ -165,8 +164,12 @@ def _analyze_form(self, form: BeautifulSoup, page_url: str) -> CSRFTestCase:
165164

166165
inputs = form.find_all(["input", "textarea", "select"])
167166
for input_field in inputs:
168-
field_name = input_field.get("name", "")
169-
field_value = input_field.get("value", "")
167+
field_name_raw = input_field.get("name", "")
168+
field_value_raw = input_field.get("value", "")
169+
170+
# Extract string values
171+
field_name = field_name_raw if isinstance(field_name_raw, str) else ""
172+
field_value = field_value_raw if isinstance(field_value_raw, str) else ""
170173

171174
if field_name:
172175
fields[field_name] = field_value
@@ -183,9 +186,9 @@ def _analyze_form(self, form: BeautifulSoup, page_url: str) -> CSRFTestCase:
183186
return CSRFTestCase(
184187
url=page_url,
185188
form_action=form_action if form_action else None,
186-
form_method=form_method
187-
if form_method in ["GET", "POST", "PUT", "DELETE", "PATCH"]
188-
else "POST",
189+
form_method=(
190+
form_method if form_method in ["GET", "POST", "PUT", "DELETE", "PATCH"] else "POST"
191+
),
189192
has_csrf_token=has_csrf,
190193
token_field_name=csrf_token_field,
191194
form_fields=fields,
@@ -195,8 +198,8 @@ def _analyze_form(self, form: BeautifulSoup, page_url: str) -> CSRFTestCase:
195198
def _check_csrf_protection(
196199
self,
197200
test_case: CSRFTestCase,
198-
headers: dict,
199-
cookies: dict,
201+
headers: dict[str, str],
202+
cookies: dict[str, Any],
200203
) -> CSRFProtectionCheck:
201204
"""Check for various CSRF protection mechanisms.
202205
@@ -212,7 +215,7 @@ def _check_csrf_protection(
212215

213216
# Check for SameSite cookie attribute
214217
for cookie in cookies.values():
215-
if hasattr(cookie, "get") and cookie.get("samesite"):
218+
if hasattr(cookie, "get") and callable(cookie.get) and cookie.get("samesite"):
216219
protection.has_samesite_cookie = True
217220
break
218221

0 commit comments

Comments
 (0)