Skip to content

Commit 36163e5

Browse files
committed
fix: harden parser, VM, recall validator and fix 4 bugs
- Rename TimeoutError alias to RecheckTimeoutError to avoid shadowing the built-in
- Emit FAIL for unknown named backreferences instead of silently resolving to group 1
- Add context manager protocol to RecallValidator and close thread pool in HybridChecker/convenience functions to prevent resource leaks
- Replace DFS with BFS in product-automaton cycle detection to avoid missing valid cycles through globally-visited intermediate nodes
- Add parser depth limit, name/backref length bounds, multiline-aware anchors, atomic groups, possessive quantifiers, and real timeouts via ThreadPoolExecutor in recall validation
- Fix ruff E402 import ordering in scc_checker and source_scanner
- 874 tests pass (11 new in test_bugfixes.py)
1 parent 40d37b1 commit 36163e5

24 files changed

+690
-236
lines changed

src/recheck/cli.py

Lines changed: 21 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,9 @@
33
import argparse
44
import sys
55

6-
from recheck import check, Config
7-
from recheck.parser.flags import Flags
6+
from redoctor import check, Config
7+
from redoctor.diagnostics.diagnostics import Status
8+
from redoctor.parser.flags import Flags
89

910

1011
def main():
@@ -60,7 +61,7 @@ def main():
6061
"-q",
6162
"--quiet",
6263
action="store_true",
63-
help="Only output status (exit code 0=safe, 1=vulnerable, 2=error)",
64+
help="Only output status (exit code 0=safe, 1=vulnerable, 2=error, 3=both)",
6465
)
6566
parser.add_argument(
6667
"-v",
@@ -77,7 +78,7 @@ def main():
7778
args = parser.parse_args()
7879

7980
if args.version:
80-
from recheck import __version__
81+
from redoctor import __version__
8182

8283
print("recheck {0}".format(__version__))
8384
return 0
@@ -113,6 +114,17 @@ def main():
113114
try:
114115
result = check(pattern, flags=flags, config=config)
115116

117+
if result.status == Status.ERROR:
118+
has_error = True
119+
if not args.quiet:
120+
print(
121+
"ERROR: {0}: {1}".format(
122+
pattern, result.error or "analysis error"
123+
),
124+
file=sys.stderr,
125+
)
126+
continue
127+
116128
if args.quiet:
117129
# Quiet mode: just track status
118130
if result.is_vulnerable:
@@ -129,7 +141,7 @@ def main():
129141
print(" Pump: {0!r}".format(result.attack_pattern.pump))
130142
print(" Suffix: {0!r}".format(result.attack_pattern.suffix))
131143
# Generate example attack string
132-
attack = result.attack_pattern.build_attack_string(20)
144+
attack = result.attack_pattern.build(20)
133145
print(" Example: {0!r}".format(attack))
134146
if result.hotspot:
135147
print("Hotspot: {0}".format(result.hotspot))
@@ -141,7 +153,7 @@ def main():
141153
if result.complexity:
142154
print(" Complexity: {0}".format(result.complexity))
143155
if result.attack_pattern:
144-
attack = result.attack_pattern.build_attack_string(20)
156+
attack = result.attack_pattern.build(20)
145157
print(" Attack: {0!r}".format(attack))
146158
has_vulnerable = True
147159
elif result.is_safe:
@@ -155,7 +167,9 @@ def main():
155167
print("ERROR: {0}: {1}".format(pattern, e), file=sys.stderr)
156168

157169
# Return appropriate exit code
158-
if has_error:
170+
if has_error and has_vulnerable:
171+
return 3
172+
elif has_error:
159173
return 2
160174
elif has_vulnerable:
161175
return 1

src/redoctor/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,9 @@
2222
from redoctor.diagnostics.attack_pattern import AttackPattern
2323
from redoctor.diagnostics.hotspot import Hotspot
2424
from redoctor.parser.flags import Flags
25-
from redoctor.exceptions import RedoctorError, ParseError, TimeoutError
25+
from redoctor.exceptions import RedoctorError, ParseError, AnalysisTimeoutError
2626

27-
__version__ = "0.1.0"
27+
__version__ = "0.1.4"
2828

2929
__all__ = [
3030
# Main API
@@ -49,7 +49,7 @@
4949
# Exceptions
5050
"RedoctorError",
5151
"ParseError",
52-
"TimeoutError",
52+
"AnalysisTimeoutError",
5353
# Version
5454
"__version__",
5555
]

src/redoctor/automaton/complexity_analyzer.py

Lines changed: 18 additions & 96 deletions
Original file line number | Diff line number | Diff line change
@@ -115,95 +115,6 @@ def analyze(self) -> Tuple[Complexity, Optional[AmbiguityWitness]]:
115115

116116
return Complexity.safe(), None
117117

118-
def _build_multi_transition_witness(self) -> Optional[AmbiguityWitness]:
119-
"""Build a witness from multi-transition information.
120-
121-
Multi-transitions indicate that multiple epsilon paths from the same
122-
state lead to the same character transition. This is the pattern for
123-
nested quantifiers like (a+)+ where:
124-
- After reading 'a' and being at state S
125-
- We can continue the inner loop (one path)
126-
- Or exit inner and re-enter via outer loop (another path)
127-
"""
128-
if not self.ordered_nfa.multi_transitions:
129-
return None
130-
131-
# Get the first multi-transition
132-
for (from_state, to_state), count in self.ordered_nfa.multi_transitions.items():
133-
if count > 1:
134-
# Find a sample character for this transition
135-
sample_char = ord("a")
136-
for trans in self.ordered_nfa.get_transitions(from_state):
137-
if trans.target == to_state and trans.char:
138-
s = trans.char.sample()
139-
if s is not None:
140-
sample_char = s
141-
break
142-
143-
# Build prefix: path from initial to from_state
144-
prefix = self._find_path_to_state(from_state)
145-
146-
# Build pump: one character that uses the multi-transition
147-
pump = [sample_char]
148-
149-
# Build suffix: non-matching character
150-
suffix = [ord("!")]
151-
152-
return AmbiguityWitness(
153-
prefix=prefix,
154-
pump=pump,
155-
suffix=suffix,
156-
state1=from_state,
157-
state2=to_state,
158-
)
159-
return None
160-
161-
def _find_path_to_state(self, target: State) -> List[int]:
162-
"""Find a path from initial state to target state."""
163-
if self.ordered_nfa.initial is None:
164-
return []
165-
if self.ordered_nfa.initial == target:
166-
return []
167-
168-
visited: Set[State] = set()
169-
queue: deque[Tuple[State, List[int]]] = deque([(self.ordered_nfa.initial, [])])
170-
171-
while queue:
172-
state, path = queue.popleft()
173-
if state == target:
174-
return path
175-
if state in visited:
176-
continue
177-
visited.add(state)
178-
179-
for trans in self.ordered_nfa.get_transitions(state):
180-
if trans.char:
181-
sample = trans.char.sample()
182-
if sample is not None:
183-
queue.append((trans.target, path + [sample]))
184-
185-
return []
186-
187-
def _check_exponential_ambiguity_with_product(
188-
self,
189-
divergent_pairs: List[NFAStatePair],
190-
product_trans: Dict[NFAStatePair, List[Tuple[IChar, NFAStatePair]]],
191-
) -> Optional[AmbiguityWitness]:
192-
"""Check for EDA using pre-computed product automaton."""
193-
for pair in divergent_pairs:
194-
cycle = self._find_cycle_in_product(pair, product_trans)
195-
if cycle and len(cycle) > 0:
196-
prefix = self._find_path_to_pair(pair, product_trans)
197-
suffix = self._find_path_to_accepting(pair, product_trans)
198-
return AmbiguityWitness(
199-
prefix=prefix,
200-
pump=cycle,
201-
suffix=suffix,
202-
state1=pair.state1,
203-
state2=pair.state2,
204-
)
205-
return None
206-
207118
def _check_polynomial_ambiguity_with_product(
208119
self,
209120
divergent_pairs: List[NFAStatePair],
@@ -320,24 +231,35 @@ def _find_cycle_in_product(
320231
start: NFAStatePair,
321232
transitions: Dict[NFAStatePair, List[Tuple[IChar, NFAStatePair]]],
322233
) -> List[int]:
323-
"""Find a cycle starting and ending at the given pair."""
234+
"""Find a cycle starting and ending at the given pair.
235+
236+
Uses BFS from the immediate successors of *start* so that each
237+
intermediate node is visited at most once while still allowing
238+
multiple successors to independently search for a path back.
239+
"""
240+
# Seed the BFS with all direct successors of start
241+
queue: deque[Tuple[NFAStatePair, List[int]]] = deque()
242+
for char, next_pair in transitions.get(start, []):
243+
sample = char.sample()
244+
if sample is not None:
245+
queue.append((next_pair, [sample]))
246+
324247
visited: Set[NFAStatePair] = set()
325-
stack: List[Tuple[NFAStatePair, List[int]]] = [(start, [])]
326248

327-
while stack:
328-
pair, chars = stack.pop()
249+
while queue:
250+
pair, chars = queue.popleft()
329251

330-
if pair == start and chars:
252+
if pair == start:
331253
return chars
332254

333-
if pair in visited and pair != start:
255+
if pair in visited:
334256
continue
335257
visited.add(pair)
336258

337259
for char, next_pair in transitions.get(pair, []):
338260
sample = char.sample()
339261
if sample is not None:
340-
stack.append((next_pair, chars + [sample]))
262+
queue.append((next_pair, chars + [sample]))
341263

342264
return []
343265

src/redoctor/automaton/scc_checker.py

Lines changed: 11 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
4. IDA: Look for divergent chains between SCCs
88
"""
99

10+
import logging
1011
from dataclasses import dataclass
1112
from enum import Enum
1213
from typing import Dict, FrozenSet, List, Optional, Set, Tuple
@@ -17,6 +18,8 @@
1718
from redoctor.diagnostics.complexity import Complexity
1819
from redoctor.unicode.ichar import IChar
1920

21+
logger = logging.getLogger(__name__)
22+
2023

2124
class MatchMode(Enum):
2225
"""How the regex is expected to be used for matching.
@@ -217,8 +220,12 @@ def check(self) -> Tuple[Complexity, Optional[AmbiguityWitness]]:
217220

218221
return Complexity.safe(), None
219222

223+
except (ValueError, KeyError, IndexError):
224+
# Known non-critical exceptions from NFA construction/analysis
225+
logger.debug("SCC analysis failed with non-critical error", exc_info=True)
226+
return Complexity.safe(), None
220227
except Exception:
221-
# Any error during analysis - return safe to be conservative
228+
logger.error("Unexpected error during SCC analysis", exc_info=True)
222229
return Complexity.safe(), None
223230

224231
def _build_quick_witness(self) -> AmbiguityWitness:
@@ -400,10 +407,6 @@ def _check_eda_with_pair_graph(
400407
# Found EDA: diagonal -> off-diagonal -> diagonal
401408
return (scc, path + [next_char] if path else [next_char])
402409

403-
# Also check if next leads to a diagonal we already saw
404-
if next_pair in visited and next_pair[0] == next_pair[1]:
405-
return (scc, path + [next_char] if path else [next_char])
406-
407410
# Continue BFS
408411
for next_char, next_pair in pair_edges.get(current, []):
409412
if next_pair not in visited:
@@ -416,50 +419,10 @@ def _check_polynomial(
416419
) -> Optional[Tuple[int, List[Tuple[List[NFAState], List[NFAChar]]]]]:
417420
"""Check for IDA (Polynomial Degree of Ambiguity).
418421
419-
IDA exists when there's a chain of SCCs with divergence accumulating.
420-
The degree is the length of the longest such chain.
422+
Note: This is a stub. IDA detection is handled by the product
423+
automaton in ComplexityAnalyzer._check_polynomial_ambiguity_with_product.
421424
"""
422-
if not self.sccs or not self.graph:
423-
return None
424-
425-
# Compute the IDA degree for each SCC using dynamic programming
426-
scc_degrees: Dict[int, int] = {}
427-
scc_pumps: Dict[int, List[Tuple[List[NFAState], List[NFAChar]]]] = {}
428-
429-
# Sort SCCs topologically (reversed order of Tarjan's output)
430-
for i, scc in enumerate(self.sccs):
431-
if self.graph.is_atom(scc):
432-
scc_degrees[i] = 0
433-
scc_pumps[i] = []
434-
else:
435-
scc_degrees[i] = 1
436-
scc_pumps[i] = []
437-
438-
# Check for IDA chains between SCCs
439-
# (This is a simplified version - full implementation would need G3 graph)
440-
max_degree = max(scc_degrees.values()) if scc_degrees else 0
441-
442-
if max_degree <= 1:
443-
return None
444-
445-
# Collect pumps for the max degree chain
446-
pumps: List[Tuple[List[NFAState], List[NFAChar]]] = []
447-
for i, degree in scc_degrees.items():
448-
if degree == max_degree:
449-
scc = self.sccs[i]
450-
# Get a sample char from this SCC
451-
sample_chars: List[NFAChar] = []
452-
for state in scc:
453-
for char, target in self.graph.neighbors.get(state, []):
454-
if target in set(scc):
455-
sample_chars.append(char)
456-
break
457-
if sample_chars:
458-
break
459-
pumps.append((scc, sample_chars))
460-
break
461-
462-
return (max_degree, pumps) if pumps else None
425+
return None
463426

464427
def _build_witness(
465428
self, eda_result: Tuple[List[NFAState], List[NFAChar]]

src/redoctor/checker.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,10 @@ def __init__(self, config: Config = None):
3030
timeout=self.config.recall_timeout,
3131
)
3232

33+
def close(self) -> None:
34+
"""Release resources held by the checker."""
35+
self.validator.close()
36+
3337
def check(self, pattern: str, flags: Flags = None) -> Diagnostics:
3438
"""Check a regex pattern for ReDoS vulnerabilities.
3539
@@ -72,7 +76,7 @@ def check_pattern(self, pattern: Pattern) -> Diagnostics:
7276
# If automaton checker returns unknown, use fuzz as fallback
7377
if result.status == Status.UNKNOWN:
7478
fuzz_result = self.fuzz_checker.check_pattern(pattern)
75-
if fuzz_result.is_vulnerable:
79+
if fuzz_result.is_vulnerable or fuzz_result.is_safe:
7680
result = fuzz_result
7781

7882
# If automaton says SAFE, trust it.
@@ -152,14 +156,17 @@ def check(pattern: str, flags: Flags = None, config: Config = None) -> Diagnosti
152156
Diagnostics result with vulnerability information.
153157
154158
Example:
155-
>>> from recheck import check
159+
>>> from redoctor import check
156160
>>> result = check(r"^(a+)+$")
157161
>>> if result.is_vulnerable:
158162
... print(f"Vulnerable: {result.complexity}")
159163
... print(f"Attack: {result.attack}")
160164
"""
161165
checker = HybridChecker(config)
162-
return checker.check(pattern, flags)
166+
try:
167+
return checker.check(pattern, flags)
168+
finally:
169+
checker.close()
163170

164171

165172
def check_pattern(pattern: Pattern, config: Config = None) -> Diagnostics:
@@ -173,7 +180,10 @@ def check_pattern(pattern: Pattern, config: Config = None) -> Diagnostics:
173180
Diagnostics result.
174181
"""
175182
checker = HybridChecker(config)
176-
return checker.check_pattern(pattern)
183+
try:
184+
return checker.check_pattern(pattern)
185+
finally:
186+
checker.close()
177187

178188

179189
def is_safe(pattern: str, flags: Flags = None, config: Config = None) -> bool:

0 commit comments

Comments (0)