11import threading
22from collections import defaultdict
3- from functools import lru_cache
3+ from collections . abc import Sequence
44from queue import Empty , Queue
5- from typing import Any , Literal , Optional
5+ from typing import TYPE_CHECKING , Any , Literal , Optional
66
7- from hypothesis .extra ._patching import (
8- get_patch_for as _get_patch_for ,
9- make_patch as _make_patch ,
10- )
7+ from hypothesis .extra ._patching import get_patch_for , make_patch as _make_patch
8+ from sortedcontainers import SortedList
119
1210from hypofuzz import __version__
1311from hypofuzz .database import Observation
1412
15- COVERING_VIA = "covering example"
16- FAILING_VIA = "discovered failure"
13+ if TYPE_CHECKING :
14+ from typing import TypeAlias
15+
16+ # we have a two tiered structure.
17+ # * First, we store the list of test case reprs corresponding to the list of
18+ # @examples.
19+ # * Each time we add a new such input, we compute the new patch for the entire
20+ # list.
21+
1722# nodeid: {
18- # "covering": [(fname, before, after), ... ],
19- # "failing": [(fname, before, after), ... ],
23+ # "covering": list[observation.representation ],
24+ # "failing": list[observation.representation ],
2025# }
21- # TODO this duplicates the test function contents in `before` and `after`,
22- # we probably want a more memory-efficient representation eventually
23- # (and a smaller win: map fname to a list of (before, after), instead of storing
24- # each fname)
25- PATCHES : dict [str , dict [str , list [tuple [str , str , str ]]]] = defaultdict (
26- lambda : {"covering" : [], "failing" : []}
26+ #
27+ # We sort by string length, as a heuristic for putting simpler examples first in
28+ # the patch.
29+ EXAMPLES : dict [str , dict [str , SortedList [str ]]] = defaultdict (
30+ lambda : {"covering" : SortedList (key = len ), "failing" : SortedList (key = len )}
2731)
28- get_patch_for = lru_cache (maxsize = 8192 )(_get_patch_for )
29-
30- _queue : Queue = Queue ()
32+ # nodeid: {
33+ # "covering": patch,
34+ # "failing": patch,
35+ # }
36+ PATCHES : dict [str , dict [str , Optional [str ]]] = defaultdict (
37+ lambda : {"covering" : None , "failing" : None }
38+ )
39+ VIA = {"covering" : "covering example" , "failing" : "discovered failure" }
40+ COMMIT_MESSAGE = {
41+ "covering" : "add covering examples" ,
42+ "failing" : "add failing examples" ,
43+ }
44+
45+ ObservationTypeT : "TypeAlias" = Literal ["covering" , "failing" ]
46+ _queue : Queue [tuple [Any , str , Observation , ObservationTypeT ]] = Queue ()
3147_thread : Optional [threading .Thread ] = None
3248
3349
@@ -36,51 +52,45 @@ def add_patch(
3652 test_function : Any ,
3753 nodeid : str ,
3854 observation : Observation ,
39- observation_type : Literal [ "covering" , "failing" ] ,
55+ observation_type : ObservationTypeT ,
4056) -> None :
4157 _queue .put ((test_function , nodeid , observation , observation_type ))
4258
4359
44- @lru_cache (maxsize = 1024 )
45- def make_patch (triples : tuple [tuple [str , str , str ]], * , msg : str ) -> str :
60+ def make_patch (
61+ function : Any , examples : Sequence [str ], observation_type : ObservationTypeT
62+ ) -> Optional [str ]:
63+ via = VIA [observation_type ]
64+ triple = get_patch_for (function , examples = [(example , via ) for example in examples ])
65+ if triple is None :
66+ return None
67+
68+ commit_message = COMMIT_MESSAGE [observation_type ]
4669 return _make_patch (
47- triples ,
48- msg = msg ,
70+ ( triple ,) ,
71+ msg = commit_message ,
4972 author = f"HypoFuzz { __version__ } <no-reply@hypofuzz.com>" ,
5073 )
5174
5275
53- def failing_patch (nodeid : str ) -> Optional [str ]:
54- failing = PATCHES [nodeid ]["failing" ]
55- return make_patch (tuple (failing ), msg = "add failing examples" ) if failing else None
56-
57-
58- def covering_patch (nodeid : str ) -> Optional [str ]:
59- covering = PATCHES [nodeid ]["covering" ]
60- return (
61- make_patch (tuple (covering ), msg = "add covering examples" ) if covering else None
62- )
63-
64-
6576def _worker () -> None :
77+ # TODO We might optimize this by checking each function ahead of time for known
78+ # reasons why a patch would fail, for instance using st.data in the signature,
79+ # and then early-returning here before calling get_patch_for.
6680 while True :
6781 try :
68- item = _queue .get (timeout = 1.0 )
82+ test_function , nodeid , observation , observation_type = _queue .get (
83+ timeout = 1.0
84+ )
6985 except Empty :
7086 continue
7187
72- test_function , nodeid , observation , observation_type = item
73-
74- via = COVERING_VIA if observation_type == "covering" else FAILING_VIA
75- # If this thread ends up using significant resources, we might optimize
76- # this by checking each function ahead of time for known reasons why a
77- # patch would fail, for instance using st.data in the signature, and then
78- # simply discarding those here entirely.
79- patch = get_patch_for (
80- test_function , ((observation .representation , via ),), strip_via = via
88+ examples = EXAMPLES [nodeid ][observation_type ]
89+ examples .add (observation .representation )
90+ PATCHES [nodeid ][observation_type ] = make_patch (
91+ test_function , examples , observation_type
8192 )
82- if patch is not None :
83- PATCHES [nodeid ][observation_type ].append (patch )
93+
8494 _queue .task_done ()
8595
8696
@@ -90,3 +100,11 @@ def start_patching_thread() -> None:
90100
91101 _thread = threading .Thread (target = _worker , daemon = True )
92102 _thread .start ()
103+
104+
105+ def failing_patch (nodeid : str ) -> Optional [str ]:
106+ return PATCHES [nodeid ]["failing" ]
107+
108+
109+ def covering_patch (nodeid : str ) -> Optional [str ]:
110+ return PATCHES [nodeid ]["covering" ]