Skip to content

Commit 8845b5b

Browse files
authored
fix(call_stack): match spy IDs in get_by_rehearsals (#55)
Closes #54
1 parent 5ce3bc5 commit 8845b5b

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

decoy/call_stack.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def consume_when_rehearsal(self) -> WhenRehearsal:
2323
"""
2424
try:
2525
call = self._stack[-1]
26-
except KeyError:
26+
except IndexError:
2727
raise MissingRehearsalError()
2828
if not isinstance(call, SpyCall):
2929
raise MissingRehearsalError()
@@ -47,12 +47,12 @@ def consume_verify_rehearsals(self, count: int) -> List[VerifyRehearsal]:
4747
return rehearsals
4848

4949
def get_by_rehearsals(self, rehearsals: Sequence[VerifyRehearsal]) -> List[SpyCall]:
50-
"""Get a list of all non-rehearsal calls to the given Spy IDs."""
50+
"""Get all non-rehearsal calls to the spies in the given rehearsals."""
5151
return [
5252
call
5353
for call in self._stack
5454
if isinstance(call, SpyCall)
55-
and any(rehearsal == call for rehearsal in rehearsals)
55+
and any(rehearsal.spy_id == call.spy_id for rehearsal in rehearsals)
5656
]
5757

5858
def get_all(self) -> List[BaseSpyCall]:

decoy/verifier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ def verify(
1818
if times is not None:
1919
if len(calls) == times:
2020
return None
21+
2122
else:
22-
for i, call in enumerate(calls):
23+
for i in range(len(calls)):
2324
calls_subset = calls[i : i + len(rehearsals)]
2425

2526
if calls_subset == rehearsals:

tests/test_call_stack.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ def test_push_and_consume_when_rehearsal() -> None:
2121
def test_consume_when_rehearsal_raises_empty_error() -> None:
2222
"""It should raise an error if the stack is empty on pop."""
2323
subject = CallStack()
24-
call = SpyCall(spy_id=42, spy_name="my_spy", args=(), kwargs={})
2524

25+
with pytest.raises(MissingRehearsalError):
26+
subject.consume_when_rehearsal()
27+
28+
call = SpyCall(spy_id=42, spy_name="my_spy", args=(), kwargs={})
2629
subject.push(call)
2730
subject.consume_when_rehearsal()
2831

@@ -64,7 +67,7 @@ def test_consume_verify_rehearsals_raises_error() -> None:
6467

6568

6669
def test_get_by_rehearsal() -> None:
67-
"""It can get a list of calls made matching a given rehearsal."""
70+
"""It can get a list of calls made matching spy IDs of given rehearsals."""
6871
subject = CallStack()
6972
call_1 = SpyCall(spy_id=101, spy_name="spy_1", args=(1,), kwargs={})
7073
call_2 = SpyCall(spy_id=101, spy_name="spy_1", args=(2,), kwargs={})
@@ -88,7 +91,7 @@ def test_get_by_rehearsal() -> None:
8891
VerifyRehearsal(spy_id=202, spy_name="spy_2", args=(1,), kwargs={}),
8992
]
9093
)
91-
assert result == [call_3]
94+
assert result == [call_1, call_3, call_4]
9295

9396
result = subject.get_by_rehearsals(
9497
[VerifyRehearsal(spy_id=303, spy_name="spy_3", args=(1,), kwargs={})]

0 commit comments

Comments
 (0)