Skip to content

Commit b1b1c39

Browse files
authored
Merge pull request #165 from Zac-HD/allocate-initial-structure
New hub and worker architecture
2 parents 98c4848 + 48f4828 commit b1b1c39

File tree

3 files changed

+306
-64
lines changed

3 files changed

+306
-64
lines changed

src/hypofuzz/entrypoint.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def _debug_ranges_disabled() -> bool:
113113

114114

115115
def _fuzz_impl(n_processes: int, pytest_args: tuple[str, ...]) -> None:
116+
from hypofuzz.hypofuzz import FuzzWorkerHub
117+
116118
if sys.version_info[:2] >= (3, 12) and _debug_ranges_disabled():
117119
raise Exception(
118120
"The current python interpreter lacks position information for its "
@@ -132,7 +134,6 @@ def _fuzz_impl(n_processes: int, pytest_args: tuple[str, ...]) -> None:
132134
)
133135

134136
from hypofuzz.collection import collect_tests
135-
from hypofuzz.hypofuzz import _fuzz
136137

137138
# With our arguments validated, it's time to actually do the work.
138139
collection = collect_tests(pytest_args)
@@ -155,25 +156,11 @@ def _fuzz_impl(n_processes: int, pytest_args: tuple[str, ...]) -> None:
155156
f"test{tests_s}{skipped_msg}"
156157
)
157158

158-
if n_processes <= 1:
159-
_fuzz(pytest_args=pytest_args, nodeids=[t.nodeid for t in tests])
160-
else:
161-
processes: list[Process] = []
162-
for i in range(n_processes):
163-
# Round-robin for large test suites; all-on-all for tiny, etc.
164-
nodeids: set[str] = set()
165-
for ix in range(n_processes):
166-
nodeids.update(t.nodeid for t in tests[i + ix :: n_processes])
167-
if len(nodeids) >= 10: # enough to prioritize between
168-
break
169-
170-
p = Process(
171-
target=_fuzz,
172-
kwargs={"pytest_args": pytest_args, "nodeids": nodeids},
173-
)
174-
p.start()
175-
processes.append(p)
176-
for p in processes:
177-
p.join()
159+
hub = FuzzWorkerHub(
160+
nodeids=[t.nodeid for t in tests],
161+
pytest_args=pytest_args,
162+
n_processes=n_processes,
163+
)
164+
hub.start()
178165

179166
print("Found a failing input for every test!", file=sys.stderr)

src/hypofuzz/hypofuzz.py

Lines changed: 246 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
import contextlib
44
import math
5+
import time
56
from collections import defaultdict
6-
from collections.abc import Callable
7+
from collections.abc import Callable, Mapping, Sequence
78
from contextlib import nullcontext
89
from functools import partial
10+
from multiprocessing import Manager, Process
911
from random import Random
1012
from typing import Any, Literal, Optional, Union
1113

@@ -124,7 +126,6 @@ def __init__(
124126
self.state: Optional[HypofuzzStateForActualGivenExecution] = None
125127
self.provider = HypofuzzProvider(None)
126128
self.stop_shrinking_at = math.inf
127-
self._fixturedefs: list[pytest.FixtureDef] = []
128129

129130
def _new_state(
130131
self, *, extra_kwargs: Optional[dict[str, Any]] = None
@@ -328,20 +329,86 @@ def has_found_failure(self) -> bool:
328329
return corpus is not None and bool(corpus.interesting_examples)
329330

330331

331-
class FuzzProcess:
332+
class FuzzWorker:
332333
"""
333-
Manages switching between several FuzzTargets, and managing their associated
334-
higher-level state, like setting up and tearing down pytest fixtures.
334+
Manages switching between several FuzzTargets, and also manages their
335+
associated higher-level state, like setting up and tearing down pytest
336+
fixtures.
335337
"""
336338

337-
def __init__(self, targets: list[FuzzTarget]) -> None:
339+
def __init__(
340+
self,
341+
*,
342+
pytest_args: Sequence[str],
343+
shared_state: Mapping,
344+
) -> None:
345+
self.pytest_args = pytest_args
346+
self.shared_state = shared_state
347+
338348
self.random = Random()
339-
self.targets = targets
349+
# The list of all collected fuzz targets. We collect this at the beginning
350+
# by running a pytest collection step.
351+
#
352+
# This is never modified or copied from after the initial collection.
353+
# When we need an actual target to fuzz, we create a new FuzzTarget
354+
# instance to put into self.targets.
355+
self.collected_targets: dict[str, FuzzTarget] = {}
356+
# the current pool of targets this worker can fuzz. This might change
357+
# based on directives from the hub.
358+
self.targets: dict[str, FuzzTarget] = {}
359+
# targets which we have previously started fuzzing, but have since been
360+
# told to drop by the hub. We keep the fuzz target in memory because we
361+
# might be told by the hub to pick this target up again in the future.
362+
#
363+
# When starting, dropping, and starting a target again, we cannot violate
364+
# the linear reports invariant that we do not write reports from the same
365+
# worker, on the same target, at two different fuzz campaigns for that
366+
# target. Once a worker starts fuzzing a target, it cannot restart fuzzing
367+
# that target from scratch without changing its uuid or wiping the previous
368+
# campaign, neither of which are feasible.
369+
self.dropped_targets: dict[str, FuzzTarget] = {}
340370

341371
self._current_target: Optional[FuzzTarget] = None
342372
self.event_dispatch: dict[bytes, list[FuzzTarget]] = defaultdict(list)
343-
for target in targets:
344-
self.event_dispatch[target.database_key].append(target)
373+
374+
def _add_target(self, nodeid: str) -> None:
375+
# if this target was previously dropped, move it from `dropped_targets`
376+
# to `targets`, without creating a new FuzzTarget.
377+
if nodeid in self.dropped_targets:
378+
target = self.dropped_targets[nodeid]
379+
self.targets[nodeid] = target
380+
del self.dropped_targets[nodeid]
381+
return
382+
383+
target = self.collected_targets[nodeid]
384+
# create a new FuzzTarget to put into self.targets, to avoid modifying
385+
# collected_targets at all
386+
target = FuzzTarget(
387+
test_fn=target.test_fn,
388+
extra_kwargs=target.extra_kwargs,
389+
database=target.database,
390+
database_key=target.database_key,
391+
wrapped_test=target.wrapped_test,
392+
pytest_item=target.pytest_item,
393+
)
394+
assert nodeid not in self.targets
395+
self.targets[nodeid] = target
396+
self.event_dispatch[target.database_key].append(target)
397+
398+
def _remove_target(self, nodeid: str) -> None:
399+
target = self.targets[nodeid]
400+
del self.targets[nodeid]
401+
402+
assert nodeid not in self.dropped_targets
403+
self.dropped_targets[nodeid] = target
404+
# we intentionally do not remove our event_dispatch listener
405+
# here, because if we are ever told to pick up this dropped target
406+
# again in the future, we still want its corpus and failure replay
407+
# to be up to date from other workers.
408+
#
409+
# This is a tradeoff between memory usage and rescan time. It's
410+
# not clear what the optimal tradeoff strategy is. (purge dropped
411+
# targets after some large number n of seconds?)
345412

346413
def on_event(self, listener_event: ListenerEventT) -> None:
347414
event = DatabaseEvent.from_event(listener_event)
@@ -354,7 +421,29 @@ def on_event(self, listener_event: ListenerEventT) -> None:
354421
@property
355422
def valid_targets(self) -> list[FuzzTarget]:
356423
# the targets we actually want to run/fuzz
357-
return [t for t in self.targets if not t.has_found_failure]
424+
return [t for t in self.targets.values() if not t.has_found_failure]
425+
426+
def _update_targets(self, nodeids: Sequence[str]) -> None:
427+
# Update our nodeids and targets with new directives from the hub.
428+
# * Nodes in both nodeids and self.targets are kept as-is
429+
# * Nodes in nodeids but not self.targets are added to our available
430+
# targets
431+
# * Nodes in self.targets but not nodeids are evicted from our targets.
432+
# These are nodes that the hub has decided are better to hand off to
433+
# another process.
434+
435+
# we get passed unique nodeids
436+
assert len(set(nodeids)) == len(nodeids)
437+
added_nodeids = set(nodeids) - set(self.targets.keys())
438+
removed_nodeids = set(self.targets.keys()) - set(nodeids)
439+
440+
for nodeid in added_nodeids:
441+
self._add_target(nodeid)
442+
443+
for nodeid in removed_nodeids:
444+
self._remove_target(nodeid)
445+
446+
assert set(self.targets.keys()) == set(nodeids)
358447

359448
def _switch_to_target(self, target: FuzzTarget) -> None:
360449
# if we're sticking with our current target, then we don't need to
@@ -371,44 +460,158 @@ def _switch_to_target(self, target: FuzzTarget) -> None:
371460
self._current_target = target
372461

373462
def start(self) -> None:
463+
self.worker_start = time.perf_counter()
464+
465+
collected = collect_tests(self.pytest_args)
466+
self.collected_targets = {
467+
target.nodeid: target for target in collected.fuzz_targets
468+
}
469+
374470
settings().database.add_listener(self.on_event)
375471

376472
while True:
377-
if not self.valid_targets:
378-
break
379-
380-
# choose the next target to fuzz with probability equal to the softmax
381-
# of its estimator. aka boltzmann exploration
382-
estimators = [behaviors_per_second(target) for target in self.valid_targets]
383-
estimators = softmax(estimators)
384-
# softmax might return 0.0 probability for some targets if there is
385-
# a substantial gap in estimator values (e.g. behaviors_per_second=1_000
386-
# vs behaviors_per_second=1.0). We don't expect this to happen normally,
387-
# but it might when our estimator state is just getting started.
388-
#
389-
# Mix in a uniform probability of 1%, so we will eventually get out of
390-
# such a hole.
391-
if self.random.random() < 0.01:
392-
target = self.random.choice(self.valid_targets)
393-
else:
394-
target = self.random.choices(
395-
self.valid_targets, weights=estimators, k=1
396-
)[0]
397-
398-
self._switch_to_target(target)
399-
# TODO we should scale this n up if our estimator expects that it will
400-
# take a long time to discover a new behavior, to reduce the overhead
401-
# of switching.
402-
for _ in range(100):
403-
target.run_one()
404-
405-
406-
def _fuzz(pytest_args: tuple[str, ...], nodeids: list[str]) -> None:
473+
self._update_targets(self.shared_state["hub_state"]["nodeids"])
474+
475+
# it's possible to go through an interim period where we have no nodeids,
476+
# but the hub still has nodeids to assign. We don't want the worker to
477+
# exit in this case, but rather keep waiting for nodeids. Even if n_workers
478+
# exceeds n_tests, we still want to keep all workers alive, because the
479+
# hub will assign the same test to multiple workers simultaneously.
480+
if self.valid_targets:
481+
# choose the next target to fuzz with probability equal to the softmax
482+
# of its estimator. aka boltzmann exploration
483+
estimators = [
484+
behaviors_per_second(target) for target in self.valid_targets
485+
]
486+
estimators = softmax(estimators)
487+
# softmax might return 0.0 probability for some targets if there is
488+
# a substantial gap in estimator values (e.g. behaviors_per_second=1_000
489+
# vs behaviors_per_second=1.0). We don't expect this to happen normally,
490+
# but it might when our estimator state is just getting started.
491+
#
492+
# Mix in a uniform probability of 1%, so we will eventually get out of
493+
# such a hole.
494+
if self.random.random() < 0.01:
495+
target = self.random.choice(self.valid_targets)
496+
else:
497+
target = self.random.choices(
498+
self.valid_targets, weights=estimators, k=1
499+
)[0]
500+
501+
self._switch_to_target(target)
502+
# TODO we should scale this n up if our estimator expects that it will
503+
# take a long time to discover a new behavior, to reduce the overhead
504+
# of switching targets.
505+
for _ in range(100):
506+
target.run_one()
507+
508+
worker_state = self.shared_state["worker_state"]
509+
worker_state["nodeids"][target.nodeid] = {
510+
"behavior_rates": None,
511+
}
512+
513+
# give the hub up-to-date estimator states
514+
current_lifetime = time.perf_counter() - self.worker_start
515+
worker_state = self.shared_state["worker_state"]
516+
worker_state["current_lifetime"] = current_lifetime
517+
worker_state["expected_lifetime"] = None
518+
519+
worker_state["valid_nodeids"] = [
520+
target.nodeid for target in self.valid_targets
521+
]
522+
523+
524+
class FuzzWorkerHub:
525+
def __init__(
526+
self,
527+
*,
528+
nodeids: Sequence[str],
529+
pytest_args: Sequence[str],
530+
n_processes: int,
531+
) -> None:
532+
self.nodeids = nodeids
533+
self.pytest_args = pytest_args
534+
self.n_processes = n_processes
535+
536+
self.shared_states: list[Mapping] = []
537+
538+
def start(self) -> None:
539+
processes: list[Process] = []
540+
541+
with Manager() as manager:
542+
for _ in range(self.n_processes):
543+
shared_state = manager.dict()
544+
shared_state["hub_state"] = manager.dict()
545+
shared_state["worker_state"] = manager.dict()
546+
shared_state["worker_state"]["nodeids"] = manager.dict()
547+
shared_state["worker_state"]["current_lifetime"] = 0.0
548+
shared_state["worker_state"]["expected_lifetime"] = 0.0
549+
shared_state["worker_state"]["valid_nodeids"] = manager.list()
550+
551+
process = Process(
552+
target=_start_worker,
553+
kwargs={
554+
"pytest_args": self.pytest_args,
555+
"shared_state": shared_state,
556+
},
557+
)
558+
processes.append(process)
559+
self.shared_states.append(shared_state)
560+
561+
# rebalance once at the start to put the initial node assignments
562+
# in the shared state
563+
self._rebalance()
564+
for process in processes:
565+
process.start()
566+
567+
while True:
568+
# rebalance automatically on an interval.
569+
# We may want to check some condition more frequently than this,
570+
# like "a process has no more nodes" (due to e.g. finding a
571+
# failure). So we rebalance either once every n seconds, or whenever
572+
# some worker needs a rebalancing.
573+
time.sleep(60)
574+
# if none of our workers have anything to do, we should exit as well
575+
if all(
576+
not state["worker_state"]["valid_nodeids"]
577+
for state in self.shared_states
578+
):
579+
break
580+
581+
self._rebalance()
582+
583+
def _rebalance(self) -> None:
584+
# rebalance the assignment of nodeids to workers, according to the
585+
# up-to-date estimators from our workers.
586+
# TODO actually defer starting up targets here, based on worker lifetime
587+
# and startup cost estimators. We should limit what we assign initially,
588+
# and only assign more as the estimator says it's worthwhile.
589+
590+
assert len(self.shared_states) == self.n_processes
591+
partitions = []
592+
for i in range(self.n_processes):
593+
# Round-robin for large test suites; all-on-all for tiny, etc.
594+
nodeids: set[str] = set()
595+
for ix in range(self.n_processes):
596+
nodeids.update(
597+
nodeid for nodeid in self.nodeids[i + ix :: self.n_processes]
598+
)
599+
if len(nodeids) >= 10: # enough to prioritize between
600+
break
601+
partitions.append(nodeids)
602+
603+
for state, nodeids in zip(self.shared_states, partitions):
604+
state["hub_state"]["nodeids"] = nodeids
605+
606+
607+
def _start_worker(
608+
pytest_args: Sequence[str],
609+
shared_state: Mapping,
610+
) -> None:
407611
"""Collect and fuzz tests.
408612
409613
Designed to be used inside a multiprocessing.Process started with the spawn()
410614
method - requires picklable arguments but works on Windows too.
411615
"""
412-
tests = [t for t in collect_tests(pytest_args).fuzz_targets if t.nodeid in nodeids]
413-
process = FuzzProcess(tests)
414-
process.start()
616+
worker = FuzzWorker(pytest_args=pytest_args, shared_state=shared_state)
617+
worker.start()

0 commit comments

Comments
 (0)