Skip to content

Commit 5b112b7

Browse files
committed
Give shared the same label as its base strategy
1 parent 030ea39 commit 5b112b7

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

hypothesis-python/src/hypothesis/strategies/_internal/shared.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@ def __init__(self, base: SearchStrategy[Ex], key: Optional[Hashable] = None):
2323
super().__init__()
2424
self.key = key
2525
self.base = base
26-
while isinstance(self.base, SharedStrategy):
27-
# Unwrap nested shares
28-
self.base = self.base.base
2926

3027
def __repr__(self) -> str:
3128
if self.key is not None:
3229
return f"shared({self.base!r}, key={self.key!r})"
3330
else:
3431
return f"shared({self.base!r})"
3532

33+
def calc_label(self) -> int:
34+
return self.base.calc_label()
35+
3636
# Ideally would be -> Ex, but key collisions with different-typed values are
3737
# possible. See https://github.com/HypothesisWorks/hypothesis/issues/4301.
3838
def do_draw(self, data: ConjectureData) -> Any:
@@ -44,7 +44,7 @@ def do_draw(self, data: ConjectureData) -> Any:
4444
drawn, other = data._shared_strategy_draws[key]
4545

4646
# Check that the strategies shared under this key are equivalent
47-
if self.base.label != other.base.label:
47+
if self.label != other.label:
4848
warnings.warn(
4949
f"Different strategies are shared under {key=}. This"
5050
" risks drawing values that are not valid examples for the strategy,"

hypothesis-python/tests/cover/test_direct_strategies.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -661,11 +661,12 @@ def test_it(a, b):
661661
def test_compatible_nested_shared_strategies_do_not_warn():
662662
shared_a = st.shared(st.integers(), key="share")
663663
shared_b = st.shared(st.integers(), key="share")
664-
shared_c = st.shared(shared_a, key="share")
664+
shared_c = st.shared(shared_a, key="nested_share")
665+
shared_d = st.shared(shared_b, key="nested_share")
665666

666-
@given(shared_a, shared_b, shared_c)
667+
@given(shared_a, shared_b, shared_c, shared_d)
667668
@settings(max_examples=10, phases=[Phase.generate])
668-
def test_it(a, b, c):
669-
assert a == b == c
669+
def test_it(a, b, c, d):
670+
assert a == b == c == d
670671

671672
test_it()

0 commit comments

Comments
 (0)