99# obtain one at https://mozilla.org/MPL/2.0/.
1010
1111import threading
12+ import warnings
1213from contextlib import contextmanager
1314
14- from hypothesis .errors import InvalidArgument
15- from hypothesis .internal .reflection import get_pretty_function_description
15+ from hypothesis .errors import HypothesisWarning , InvalidArgument
16+ from hypothesis .internal .reflection import (
17+ get_pretty_function_description ,
18+ is_identity_function ,
19+ )
1620from hypothesis .internal .validation import check_type
1721from hypothesis .strategies ._internal .strategies import (
1822 OneOfStrategy ,
@@ -72,24 +76,36 @@ def capped(self, max_templates):
7276
7377
7478class RecursiveStrategy (SearchStrategy ):
75- def __init__ (self , base , extend , max_leaves ):
79+ def __init__ (self , base , extend , min_leaves , max_leaves ):
7680 super ().__init__ ()
81+ self .min_leaves = min_leaves
7782 self .max_leaves = max_leaves
7883 self .base = base
7984 self .limited_base = LimitedStrategy (base )
8085 self .extend = extend
8186
87+ if is_identity_function (extend ):
88+ warnings .warn (
89+ "extend=lambda x: x is a no-op; you probably want to use a "
90+ "different extend function, or just use the base strategy directly." ,
91+ HypothesisWarning ,
92+ stacklevel = 5 ,
93+ )
94+
8295 strategies = [self .limited_base , self .extend (self .limited_base )]
8396 while 2 ** (len (strategies ) - 1 ) <= max_leaves :
8497 strategies .append (extend (OneOfStrategy (tuple (strategies ))))
98+ # If min_leaves > 1, we can never draw from base directly
99+ if min_leaves > 1 :
100+ strategies = strategies [1 :]
85101 self .strategy = OneOfStrategy (strategies )
86102
87103 def __repr__ (self ) -> str :
88104 if not hasattr (self , "_cached_repr" ):
89- self ._cached_repr = "recursive(%r, %s, max_leaves=%d)" % (
90- self .base ,
91- get_pretty_function_description (self .extend ),
92- self .max_leaves ,
105+ self ._cached_repr = (
106+ f"recursive( { self .base !r } , "
107+ f" { get_pretty_function_description (self .extend )} , "
108+ f"min_leaves= { self .min_leaves } , max_leaves= { self . max_leaves } )"
93109 )
94110 return self ._cached_repr
95111
@@ -99,20 +115,41 @@ def do_validate(self) -> None:
99115 check_strategy (extended , f"extend({ self .limited_base !r} )" )
100116 self .limited_base .validate ()
101117 extended .validate ()
118+ check_type (int , self .min_leaves , "min_leaves" )
102119 check_type (int , self .max_leaves , "max_leaves" )
120+ if self .min_leaves <= 0 :
121+ raise InvalidArgument (
122+ f"min_leaves={ self .min_leaves !r} must be greater than zero"
123+ )
103124 if self .max_leaves <= 0 :
104125 raise InvalidArgument (
105126 f"max_leaves={ self .max_leaves !r} must be greater than zero"
106127 )
128+ if self .min_leaves > self .max_leaves :
129+ raise InvalidArgument (
130+ f"min_leaves={ self .min_leaves !r} must be less than or equal to "
131+ f"max_leaves={ self .max_leaves !r} "
132+ )
107133
108134 def do_draw (self , data ):
109- count = 0
135+ min_leaves_retries = 0
110136 while True :
111137 try :
112138 with self .limited_base .capped (self .max_leaves ):
113- return data .draw (self .strategy )
139+ result = data .draw (self .strategy )
140+ leaves_drawn = self .max_leaves - self .limited_base .marker
141+ if leaves_drawn < self .min_leaves :
142+ data .events [
143+ f"Draw for { self !r} had fewer than "
144+ f"min_leaves={ self .min_leaves } and had to be retried"
145+ ] = ""
146+ min_leaves_retries += 1
147+ if min_leaves_retries < 5 :
148+ continue
149+ data .mark_invalid (f"min_leaves={ self .min_leaves } unsatisfied" )
150+ return result
114151 except LimitReached :
115- if count == 0 :
116- msg = f"Draw for { self !r} exceeded max_leaves and had to be retried "
117- data . events [ msg ] = " "
118- count += 1
152+ data . events [
153+ f"Draw for { self !r} exceeded "
154+ f"max_leaves= { self . max_leaves } and had to be retried "
155+ ] = ""
0 commit comments