Skip to content

Commit e5f9919

Browse files
authored
Merge pull request #4616 from Zac-HD/claude/add-min-leaves-parameter-01XYPfenzf6xuY64W4tMDpAT
Add `min_leaves` parameter to `st.recursive()`
2 parents b644550 + 1458abe commit e5f9919

File tree

4 files changed

+105
-18
lines changed

4 files changed

+105
-18
lines changed

hypothesis-python/RELEASE.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
RELEASE_TYPE: minor
2+
3+
This release adds a ``min_leaves`` argument to :func:`~hypothesis.strategies.recursive`,
4+
which ensures that generated recursive structures have at least the specified number
5+
of leaf nodes (:issue:`4205`).

hypothesis-python/src/hypothesis/strategies/_internal/core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1857,12 +1857,16 @@ def recursive(
18571857
base: SearchStrategy[Ex],
18581858
extend: Callable[[SearchStrategy[Any]], SearchStrategy[T]],
18591859
*,
1860+
min_leaves: int = 1,
18601861
max_leaves: int = 100,
18611862
) -> SearchStrategy[T | Ex]:
18621863
"""base: A strategy to start from.
18631864
18641865
extend: A function which takes a strategy and returns a new strategy.
18651866
1867+
min_leaves: The minimum number of elements to be drawn from base on a given
1868+
run.
1869+
18661870
max_leaves: The maximum number of elements to be drawn from base on a given
18671871
run.
18681872
@@ -1881,7 +1885,7 @@ def recursive(
18811885
18821886
"""
18831887

1884-
return RecursiveStrategy(base, extend, max_leaves)
1888+
return RecursiveStrategy(base, extend, min_leaves, max_leaves)
18851889

18861890

18871891
class PermutationStrategy(SearchStrategy):

hypothesis-python/src/hypothesis/strategies/_internal/recursive.py

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

1111
import threading
12+
import warnings
1213
from contextlib import contextmanager
1314

14-
from hypothesis.errors import InvalidArgument
15-
from hypothesis.internal.reflection import get_pretty_function_description
15+
from hypothesis.errors import HypothesisWarning, InvalidArgument
16+
from hypothesis.internal.reflection import (
17+
get_pretty_function_description,
18+
is_identity_function,
19+
)
1620
from hypothesis.internal.validation import check_type
1721
from hypothesis.strategies._internal.strategies import (
1822
OneOfStrategy,
@@ -72,24 +76,36 @@ def capped(self, max_templates):
7276

7377

7478
class RecursiveStrategy(SearchStrategy):
75-
def __init__(self, base, extend, max_leaves):
79+
def __init__(self, base, extend, min_leaves, max_leaves):
7680
super().__init__()
81+
self.min_leaves = min_leaves
7782
self.max_leaves = max_leaves
7883
self.base = base
7984
self.limited_base = LimitedStrategy(base)
8085
self.extend = extend
8186

87+
if is_identity_function(extend):
88+
warnings.warn(
89+
"extend=lambda x: x is a no-op; you probably want to use a "
90+
"different extend function, or just use the base strategy directly.",
91+
HypothesisWarning,
92+
stacklevel=5,
93+
)
94+
8295
strategies = [self.limited_base, self.extend(self.limited_base)]
8396
while 2 ** (len(strategies) - 1) <= max_leaves:
8497
strategies.append(extend(OneOfStrategy(tuple(strategies))))
98+
# If min_leaves > 1, we can never draw from base directly
99+
if min_leaves > 1:
100+
strategies = strategies[1:]
85101
self.strategy = OneOfStrategy(strategies)
86102

87103
def __repr__(self) -> str:
88104
if not hasattr(self, "_cached_repr"):
89-
self._cached_repr = "recursive(%r, %s, max_leaves=%d)" % (
90-
self.base,
91-
get_pretty_function_description(self.extend),
92-
self.max_leaves,
105+
self._cached_repr = (
106+
f"recursive({self.base!r}, "
107+
f"{get_pretty_function_description(self.extend)}, "
108+
f"min_leaves={self.min_leaves}, max_leaves={self.max_leaves})"
93109
)
94110
return self._cached_repr
95111

@@ -99,20 +115,41 @@ def do_validate(self) -> None:
99115
check_strategy(extended, f"extend({self.limited_base!r})")
100116
self.limited_base.validate()
101117
extended.validate()
118+
check_type(int, self.min_leaves, "min_leaves")
102119
check_type(int, self.max_leaves, "max_leaves")
120+
if self.min_leaves <= 0:
121+
raise InvalidArgument(
122+
f"min_leaves={self.min_leaves!r} must be greater than zero"
123+
)
103124
if self.max_leaves <= 0:
104125
raise InvalidArgument(
105126
f"max_leaves={self.max_leaves!r} must be greater than zero"
106127
)
128+
if self.min_leaves > self.max_leaves:
129+
raise InvalidArgument(
130+
f"min_leaves={self.min_leaves!r} must be less than or equal to "
131+
f"max_leaves={self.max_leaves!r}"
132+
)
107133

108134
def do_draw(self, data):
109-
count = 0
135+
min_leaves_retries = 0
110136
while True:
111137
try:
112138
with self.limited_base.capped(self.max_leaves):
113-
return data.draw(self.strategy)
139+
result = data.draw(self.strategy)
140+
leaves_drawn = self.max_leaves - self.limited_base.marker
141+
if leaves_drawn < self.min_leaves:
142+
data.events[
143+
f"Draw for {self!r} had fewer than "
144+
f"min_leaves={self.min_leaves} and had to be retried"
145+
] = ""
146+
min_leaves_retries += 1
147+
if min_leaves_retries < 5:
148+
continue
149+
data.mark_invalid(f"min_leaves={self.min_leaves} unsatisfied")
150+
return result
114151
except LimitReached:
115-
if count == 0:
116-
msg = f"Draw for {self!r} exceeded max_leaves and had to be retried"
117-
data.events[msg] = ""
118-
count += 1
152+
data.events[
153+
f"Draw for {self!r} exceeded "
154+
f"max_leaves={self.max_leaves} and had to be retried"
155+
] = ""

hypothesis-python/tests/cover/test_recursive.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010

1111
import pytest
1212

13-
from hypothesis import given, strategies as st
14-
from hypothesis.errors import InvalidArgument
15-
16-
from tests.common.debug import check_can_generate_examples, find_any, minimal
13+
from hypothesis import given, note, settings, strategies as st
14+
from hypothesis.errors import HypothesisWarning, InvalidArgument
15+
16+
from tests.common.debug import (
17+
assert_all_examples,
18+
check_can_generate_examples,
19+
find_any,
20+
minimal,
21+
)
1722

1823

1924
@given(st.recursive(st.booleans(), st.lists, max_leaves=10))
@@ -83,8 +88,44 @@ def test_issue_1502_regression(s):
8388
st.recursive(st.none(), st.lists, max_leaves=-1),
8489
st.recursive(st.none(), st.lists, max_leaves=0),
8590
st.recursive(st.none(), st.lists, max_leaves=1.0),
91+
st.recursive(st.none(), st.lists, min_leaves=-1),
92+
st.recursive(st.none(), st.lists, min_leaves=0),
93+
st.recursive(st.none(), st.lists, min_leaves=1.0),
94+
st.recursive(st.none(), st.lists, min_leaves=10, max_leaves=5),
8695
],
8796
)
8897
def test_invalid_args(s):
8998
with pytest.raises(InvalidArgument):
9099
check_can_generate_examples(s)
100+
101+
102+
def _count_leaves(tree):
103+
if isinstance(tree, tuple):
104+
return sum(_count_leaves(child) for child in tree)
105+
return 1
106+
107+
108+
@given(st.data())
109+
@settings(max_examples=5, suppress_health_check=["filter_too_much"])
110+
def test_respects_min_leaves(data):
111+
min_leaves = data.draw(st.integers(1, 20))
112+
max_leaves = data.draw(st.integers(min_leaves, 40))
113+
note(f"{min_leaves=}")
114+
note(f"{max_leaves=}")
115+
s = st.recursive(
116+
st.none(),
117+
lambda x: st.tuples(x, x),
118+
min_leaves=min_leaves,
119+
max_leaves=max_leaves,
120+
)
121+
assert_all_examples(s, lambda tree: min_leaves <= _count_leaves(tree) <= max_leaves)
122+
123+
124+
@given(st.recursive(st.none(), lambda x: st.tuples(x, x), min_leaves=5, max_leaves=5))
125+
def test_can_set_exact_leaf_count(tree):
126+
assert _count_leaves(tree) == 5
127+
128+
129+
def test_identity_extend_warns():
130+
with pytest.warns(HypothesisWarning, match="extend=lambda x: x is a no-op"):
131+
st.recursive(st.none(), lambda x: x)

0 commit comments

Comments
 (0)