diff --git a/src/hypofuzz/entrypoint.py b/src/hypofuzz/entrypoint.py
index e21b45c6..b14adfc6 100644
--- a/src/hypofuzz/entrypoint.py
+++ b/src/hypofuzz/entrypoint.py
@@ -113,6 +113,8 @@ def _debug_ranges_disabled() -> bool:
 
 
 def _fuzz_impl(n_processes: int, pytest_args: tuple[str, ...]) -> None:
+    from hypofuzz.hypofuzz import FuzzWorkerHub
+
     if sys.version_info[:2] >= (3, 12) and _debug_ranges_disabled():
         raise Exception(
             "The current python interpreter lacks position information for its "
@@ -132,7 +134,6 @@ def _fuzz_impl(n_processes: int, pytest_args: tuple[str, ...]) -> None:
         )
 
     from hypofuzz.collection import collect_tests
-    from hypofuzz.hypofuzz import _fuzz
 
     # With our arguments validated, it's time to actually do the work.
     collection = collect_tests(pytest_args)
@@ -155,25 +156,11 @@ def _fuzz_impl(n_processes: int, pytest_args: tuple[str, ...]) -> None:
         f"test{tests_s}{skipped_msg}"
     )
 
-    if n_processes <= 1:
-        _fuzz(pytest_args=pytest_args, nodeids=[t.nodeid for t in tests])
-    else:
-        processes: list[Process] = []
-        for i in range(n_processes):
-            # Round-robin for large test suites; all-on-all for tiny, etc.
-            nodeids: set[str] = set()
-            for ix in range(n_processes):
-                nodeids.update(t.nodeid for t in tests[i + ix :: n_processes])
-                if len(nodeids) >= 10:  # enough to prioritize between
-                    break
-
-            p = Process(
-                target=_fuzz,
-                kwargs={"pytest_args": pytest_args, "nodeids": nodeids},
-            )
-            p.start()
-            processes.append(p)
-        for p in processes:
-            p.join()
+    hub = FuzzWorkerHub(
+        nodeids=[t.nodeid for t in tests],
+        pytest_args=pytest_args,
+        n_processes=n_processes,
+    )
+    hub.start()
 
     print("Found a failing input for every test!", file=sys.stderr)
diff --git a/src/hypofuzz/hypofuzz.py b/src/hypofuzz/hypofuzz.py
index 89a194db..4b7d3d5e 100644
--- a/src/hypofuzz/hypofuzz.py
+++ b/src/hypofuzz/hypofuzz.py
@@ -2,10 +2,12 @@
 
 import contextlib
 import math
+import time
 from collections import defaultdict
-from collections.abc import Callable
+from collections.abc import Callable, Mapping, Sequence
 from contextlib import nullcontext
 from functools import partial
+from multiprocessing import Manager, Process
 from random import Random
 from typing import Any, Literal, Optional, Union
 
@@ -124,7 +126,6 @@ def __init__(
         self.state: Optional[HypofuzzStateForActualGivenExecution] = None
         self.provider = HypofuzzProvider(None)
         self.stop_shrinking_at = math.inf
-        self._fixturedefs: list[pytest.FixtureDef] = []
 
     def _new_state(
         self, *, extra_kwargs: Optional[dict[str, Any]] = None
@@ -328,20 +329,86 @@ def has_found_failure(self) -> bool:
         return corpus is not None and bool(corpus.interesting_examples)
 
 
-class FuzzProcess:
+class FuzzWorker:
     """
-    Manages switching between several FuzzTargets, and managing their associated
-    higher-level state, like setting up and tearing down pytest fixtures.
+    Manages switching between several FuzzTargets, and also manages their
+    associated higher-level state, like setting up and tearing down pytest
+    fixtures.
     """
 
-    def __init__(self, targets: list[FuzzTarget]) -> None:
+    def __init__(
+        self,
+        *,
+        pytest_args: Sequence[str],
+        shared_state: Mapping,
+    ) -> None:
+        self.pytest_args = pytest_args
+        self.shared_state = shared_state
+
         self.random = Random()
-        self.targets = targets
+        # The list of all collected fuzz targets. We collect this at the beginning
+        # by running a pytest collection step.
+        #
+        # This is never modified after the initial collection, only copied from.
+        # When we need an actual target to fuzz, we create a new FuzzTarget
+        # instance to put into self.targets.
+        self.collected_targets: dict[str, FuzzTarget] = {}
+        # the current pool of targets this worker can fuzz. This might change
+        # based on directives from the hub.
+        self.targets: dict[str, FuzzTarget] = {}
+        # targets which we have previously started fuzzing, but have since been
+        # told to drop by the hub. We keep the fuzz target in memory because we
+        # might be told by the hub to pick this target up again in the future.
+        #
+        # When starting, dropping, and starting a target again, we cannot violate
+        # the linear reports invariant: we never write reports from the same
+        # worker, on the same target, across two different fuzz campaigns for that
+        # target. Once a worker starts fuzzing a target, it cannot restart fuzzing
+        # that target from scratch without changing its uuid or wiping the previous
+        # campaign, neither of which is feasible.
+        self.dropped_targets: dict[str, FuzzTarget] = {}
         self._current_target: Optional[FuzzTarget] = None
 
         self.event_dispatch: dict[bytes, list[FuzzTarget]] = defaultdict(list)
-        for target in targets:
-            self.event_dispatch[target.database_key].append(target)
+
+    def _add_target(self, nodeid: str) -> None:
+        # if this target was previously dropped, move it from `dropped_targets`
+        # to `targets`, without creating a new FuzzTarget.
+        if nodeid in self.dropped_targets:
+            target = self.dropped_targets[nodeid]
+            self.targets[nodeid] = target
+            del self.dropped_targets[nodeid]
+            return
+
+        target = self.collected_targets[nodeid]
+        # create a new FuzzTarget to put into self.targets, to avoid modifying
+        # self.collected_targets at all
+        target = FuzzTarget(
+            test_fn=target.test_fn,
+            extra_kwargs=target.extra_kwargs,
+            database=target.database,
+            database_key=target.database_key,
+            wrapped_test=target.wrapped_test,
+            pytest_item=target.pytest_item,
+        )
+        assert nodeid not in self.targets
+        self.targets[nodeid] = target
+        self.event_dispatch[target.database_key].append(target)
+
+    def _remove_target(self, nodeid: str) -> None:
+        target = self.targets[nodeid]
+        del self.targets[nodeid]
+
+        assert nodeid not in self.dropped_targets
+        self.dropped_targets[nodeid] = target
+        # we intentionally do not remove our event_dispatch listener
+        # here, because if we are ever told to pick up this dropped target
+        # again in the future, we still want its corpus and failure replay
+        # to be up to date from other workers.
+        #
+        # This is a tradeoff between memory usage and rescan time. It's
+        # not clear what the optimal tradeoff strategy is. (purge dropped
+        # targets after some large number of seconds?)
 
     def on_event(self, listener_event: ListenerEventT) -> None:
         event = DatabaseEvent.from_event(listener_event)
@@ -354,7 +421,29 @@ def on_event(self, listener_event: ListenerEventT) -> None:
     @property
     def valid_targets(self) -> list[FuzzTarget]:
         # the targets we actually want to run/fuzz
-        return [t for t in self.targets if not t.has_found_failure]
+        return [t for t in self.targets.values() if not t.has_found_failure]
+
+    def _update_targets(self, nodeids: Sequence[str]) -> None:
+        # Update our nodeids and targets with new directives from the hub.
+        # * Nodes in both nodeids and self.targets are kept as-is
+        # * Nodes in nodeids but not self.targets are added to our available
+        #   targets
+        # * Nodes in self.targets but not nodeids are evicted from our targets.
+        #   These are nodes that the hub has decided are better to hand off to
+        #   another process.
+
+        # we get passed unique nodeids
+        assert len(set(nodeids)) == len(nodeids)
+        added_nodeids = set(nodeids) - set(self.targets.keys())
+        removed_nodeids = set(self.targets.keys()) - set(nodeids)
+
+        for nodeid in added_nodeids:
+            self._add_target(nodeid)
+
+        for nodeid in removed_nodeids:
+            self._remove_target(nodeid)
+
+        assert set(self.targets.keys()) == set(nodeids)
 
     def _switch_to_target(self, target: FuzzTarget) -> None:
         # if we're sticking with our current target, then we don't need to
@@ -371,44 +460,158 @@ def _switch_to_target(self, target: FuzzTarget) -> None:
         self._current_target = target
 
     def start(self) -> None:
+        self.worker_start = time.perf_counter()
+
+        collected = collect_tests(self.pytest_args)
+        self.collected_targets = {
+            target.nodeid: target for target in collected.fuzz_targets
+        }
+
         settings().database.add_listener(self.on_event)
         while True:
-            if not self.valid_targets:
-                break
-
-            # choose the next target to fuzz with probability equal to the softmax
-            # of its estimator. aka boltzmann exploration
-            estimators = [behaviors_per_second(target) for target in self.valid_targets]
-            estimators = softmax(estimators)
-            # softmax might return 0.0 probability for some targets if there is
-            # a substantial gap in estimator values (e.g. behaviors_per_second=1_000
-            # vs behaviors_per_second=1.0). We don't expect this to happen normally,
-            # but it might when our estimator state is just getting started.
-            #
-            # Mix in a uniform probability of 1%, so we will eventually get out of
-            # such a hole.
-            if self.random.random() < 0.01:
-                target = self.random.choice(self.valid_targets)
-            else:
-                target = self.random.choices(
-                    self.valid_targets, weights=estimators, k=1
-                )[0]
-
-            self._switch_to_target(target)
-            # TODO we should scale this n up if our estimator expects that it will
-            # take a long time to discover a new behavior, to reduce the overhead
-            # of switching.
-            for _ in range(100):
-                target.run_one()
-
-
-def _fuzz(pytest_args: tuple[str, ...], nodeids: list[str]) -> None:
+            self._update_targets(self.shared_state["hub_state"]["nodeids"])
+
+            # it's possible to go through an interim period where we have no nodeids,
+            # but the hub still has nodeids to assign. We don't want the worker to
+            # exit in this case, but rather keep waiting for nodeids. Even if n_workers
+            # exceeds n_tests, we still want to keep all workers alive, because the
+            # hub will assign the same test to multiple workers simultaneously.
+            if self.valid_targets:
+                # choose the next target to fuzz with probability equal to the softmax
+                # of its estimator. aka boltzmann exploration
+                estimators = [
+                    behaviors_per_second(target) for target in self.valid_targets
+                ]
+                estimators = softmax(estimators)
+                # softmax might return 0.0 probability for some targets if there is
+                # a substantial gap in estimator values (e.g. behaviors_per_second=1_000
+                # vs behaviors_per_second=1.0). We don't expect this to happen normally,
+                # but it might when our estimator state is just getting started.
+                #
+                # Mix in a uniform probability of 1%, so we will eventually get out of
+                # such a hole.
+                if self.random.random() < 0.01:
+                    target = self.random.choice(self.valid_targets)
+                else:
+                    target = self.random.choices(
+                        self.valid_targets, weights=estimators, k=1
+                    )[0]
+
+                self._switch_to_target(target)
+                # TODO we should scale this n up if our estimator expects that it will
+                # take a long time to discover a new behavior, to reduce the overhead
+                # of switching targets.
+                for _ in range(100):
+                    target.run_one()
+
+                worker_state = self.shared_state["worker_state"]
+                worker_state["nodeids"][target.nodeid] = {
+                    "behavior_rates": None,
+                }
+
+            # give the hub up-to-date estimator states
+            current_lifetime = time.perf_counter() - self.worker_start
+            worker_state = self.shared_state["worker_state"]
+            worker_state["current_lifetime"] = current_lifetime
+            worker_state["expected_lifetime"] = None
+
+            worker_state["valid_nodeids"] = [
+                target.nodeid for target in self.valid_targets
+            ]
+
+
+class FuzzWorkerHub:
+    def __init__(
+        self,
+        *,
+        nodeids: Sequence[str],
+        pytest_args: Sequence[str],
+        n_processes: int,
+    ) -> None:
+        self.nodeids = nodeids
+        self.pytest_args = pytest_args
+        self.n_processes = n_processes
+
+        self.shared_states: list[Mapping] = []
+
+    def start(self) -> None:
+        processes: list[Process] = []
+
+        with Manager() as manager:
+            for _ in range(self.n_processes):
+                shared_state = manager.dict()
+                shared_state["hub_state"] = manager.dict()
+                shared_state["worker_state"] = manager.dict()
+                shared_state["worker_state"]["nodeids"] = manager.dict()
+                shared_state["worker_state"]["current_lifetime"] = 0.0
+                shared_state["worker_state"]["expected_lifetime"] = 0.0
+                shared_state["worker_state"]["valid_nodeids"] = manager.list()
+
+                process = Process(
+                    target=_start_worker,
+                    kwargs={
+                        "pytest_args": self.pytest_args,
+                        "shared_state": shared_state,
+                    },
+                )
+                processes.append(process)
+                self.shared_states.append(shared_state)
+
+            # rebalance once at the start to put the initial node assignments
+            # in the shared state
+            self._rebalance()
+            for process in processes:
+                process.start()
+
+            while True:
+                # Rebalance automatically on an interval. We may eventually want
+                # to check some condition more frequently than this, like
+                # "a process has no more nodes" (due to e.g. finding a failure),
+                # and rebalance whenever some worker needs it, rather than only
+                # once every n seconds.
+                time.sleep(60)
+                # if none of our workers have anything to do, we should exit as well
+                if all(
+                    not state["worker_state"]["valid_nodeids"]
+                    for state in self.shared_states
+                ):
+                    break
+
+                self._rebalance()
+
+    def _rebalance(self) -> None:
+        # rebalance the assignment of nodeids to workers, according to the
+        # up-to-date estimators from our workers.
+        # TODO actually defer starting up targets here, based on worker lifetime
+        # and startup cost estimators. We should limit what we assign initially,
+        # and only assign more as the estimator says it's worthwhile.
+
+        assert len(self.shared_states) == self.n_processes
+        partitions = []
+        for i in range(self.n_processes):
+            # Round-robin for large test suites; all-on-all for tiny, etc.
+            nodeids: set[str] = set()
+            for ix in range(self.n_processes):
+                nodeids.update(
+                    nodeid for nodeid in self.nodeids[i + ix :: self.n_processes]
+                )
+                if len(nodeids) >= 10:  # enough to prioritize between
+                    break
+            partitions.append(nodeids)
+
+        for state, nodeids in zip(self.shared_states, partitions):
+            state["hub_state"]["nodeids"] = nodeids
+
+
+def _start_worker(
+    pytest_args: Sequence[str],
+    shared_state: Mapping,
+) -> None:
     """Collect and fuzz tests.
 
     Designed to be used inside a multiprocessing.Process started with the spawn()
     method - requires picklable arguments but works on Windows too.
""" - tests = [t for t in collect_tests(pytest_args).fuzz_targets if t.nodeid in nodeids] - process = FuzzProcess(tests) - process.start() + worker = FuzzWorker(pytest_args=pytest_args, shared_state=shared_state) + worker.start() diff --git a/tests/test_workers.py b/tests/test_workers.py new file mode 100644 index 00000000..189abbf3 --- /dev/null +++ b/tests/test_workers.py @@ -0,0 +1,52 @@ +import multiprocessing +from multiprocessing import Process + +from common import setup_test_code, wait_for + +from hypofuzz.hypofuzz import _start_worker + +test_code = """ +@given(st.integers()) +def test_a(n): + pass + +@given(st.integers()) +def test_b(): + pass + +@given(st.integers()) +def test_c(): + pass +""" + + +def test_workers(tmp_path): + test_dir, _db_dir = setup_test_code(tmp_path, test_code) + + with multiprocessing.Manager() as manager: + shared_state = manager.dict() + shared_state["hub_state"] = manager.dict() + shared_state["hub_state"]["nodeids"] = [] + shared_state["worker_state"] = manager.dict() + shared_state["worker_state"]["nodeids"] = manager.dict() + shared_state["worker_state"]["valid_nodeids"] = manager.list() + shared_state["worker_state"]["current_lifetime"] = 0.0 + shared_state["worker_state"]["expected_lifetime"] = 0.0 + process = Process( + target=_start_worker, + kwargs={"pytest_args": [str(test_dir)], "shared_state": shared_state}, + ) + process.start() + + assert shared_state["hub_state"]["nodeids"] == [] + + shared_state["hub_state"]["nodeids"] = ["test_a.py::test_a"] + wait_for( + lambda: shared_state["worker_state"]["valid_nodeids"] + == ["test_a.py::test_a"], + interval=0.01, + ) + assert shared_state["worker_state"]["current_lifetime"] > 0.0 + + process.kill() + process.join()