Skip to content

Commit f298403

Browse files
committed
Update tests to use the new Receiver.matches method
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 8b2b087 commit f298403

File tree

3 files changed

+32
-33
lines changed

3 files changed

+32
-33
lines changed

tests/test_file_watcher_integration.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import pytest
1111

12-
from frequenz.channels import ReceiverStoppedError, select, selected_from
12+
from frequenz.channels import ReceiverStoppedError, select
1313
from frequenz.channels.file_watcher import EventType, FileWatcher
1414
from frequenz.channels.timer import SkipMissedAndDrift, Timer
1515

@@ -32,9 +32,9 @@ async def test_file_watcher(tmp_path: pathlib.Path) -> None:
3232
timer = Timer(timedelta(seconds=0.1), SkipMissedAndDrift())
3333

3434
async for selected in select(file_watcher, timer):
35-
if selected_from(selected, timer):
35+
if timer.matches(selected):
3636
filename.write_text(f"{selected.message}")
37-
elif selected_from(selected, file_watcher):
37+
elif file_watcher.matches(selected):
3838
event_type = EventType.CREATE if number_of_writes == 0 else EventType.MODIFY
3939
event = selected.message
4040
# If we receive updates for the directory itself, they should be only
@@ -93,18 +93,18 @@ async def test_file_watcher_deletes(tmp_path: pathlib.Path) -> None:
9393
# D: Delete
9494
# E: FileWatcher Event
9595
async for selected in select(file_watcher, write_timer, deletion_timer):
96-
if selected_from(selected, write_timer):
96+
if write_timer.matches(selected):
9797
if number_of_write >= 2 and number_of_events == 0:
9898
continue
9999
filename.write_text(f"{selected.message}")
100100
number_of_write += 1
101-
elif selected_from(selected, deletion_timer):
101+
elif deletion_timer.matches(selected):
102102
# Avoid removing the file twice
103103
if not pathlib.Path(filename).is_file():
104104
continue
105105
os.remove(filename)
106106
number_of_deletes += 1
107-
elif selected_from(selected, file_watcher):
107+
elif file_watcher.matches(selected):
108108
number_of_events += 1
109109
if number_of_events >= 2:
110110
break
@@ -135,9 +135,9 @@ async def test_file_watcher_exit_iterator(tmp_path: pathlib.Path) -> None:
135135
timer = Timer(timedelta(seconds=0.1), SkipMissedAndDrift())
136136

137137
async for selected in select(file_watcher, timer):
138-
if selected_from(selected, timer):
138+
if timer.matches(selected):
139139
filename.write_text(f"{selected.message}")
140-
elif selected_from(selected, file_watcher):
140+
elif file_watcher.matches(selected):
141141
number_of_writes += 1
142142
if number_of_writes == expected_number_of_writes:
143143
file_watcher._stop_event.set() # pylint: disable=protected-access

tests/test_select.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import pytest
99

10-
from frequenz.channels import Receiver, ReceiverStoppedError, Selected, selected_from
10+
from frequenz.channels import Receiver, ReceiverStoppedError, Selected
1111

1212

1313
class TestSelected:
@@ -19,7 +19,7 @@ def test_with_message(self) -> None:
1919
recv.consume.return_value = 42
2020
selected = Selected[int](recv)
2121

22-
assert selected_from(selected, recv)
22+
assert recv.matches(selected)
2323
assert selected.message == 42
2424
assert selected.exception is None
2525
assert not selected.was_stopped
@@ -31,7 +31,7 @@ def test_with_exception(self) -> None:
3131
recv.consume.side_effect = exception
3232
selected = Selected[int](recv)
3333

34-
assert selected_from(selected, recv)
34+
assert recv.matches(selected)
3535
with pytest.raises(Exception, match="test"):
3636
_ = selected.message
3737
assert selected.exception is exception
@@ -44,7 +44,7 @@ def test_with_stopped(self) -> None:
4444
recv.consume.side_effect = exception
4545
selected = Selected[int](recv)
4646

47-
assert selected_from(selected, recv)
47+
assert recv.matches(selected)
4848
with pytest.raises(
4949
ReceiverStoppedError,
5050
match=r"Receiver <MagicMock spec='_GenericAlias' id='\d+'> was stopped",

tests/test_select_integration.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class at a time.
2121
Selected,
2222
UnhandledSelectedError,
2323
select,
24-
selected_from,
2524
)
2625
from frequenz.channels.event import Event
2726

@@ -87,7 +86,7 @@ def assert_received_from(
8786
number is negative, a > check is performed with the absolute value. If
8887
it is 0, no check is performed.
8988
"""
90-
assert selected_from(selected, receiver)
89+
assert receiver.matches(selected)
9190
assert selected.message is None
9291
assert selected.exception is None
9392
assert not selected.was_stopped
@@ -120,7 +119,7 @@ def assert_receiver_stopped(
120119
number is negative, a > check is performed with the absolute value. If
121120
it is 0, no check is performed.
122121
"""
123-
assert selected_from(selected, receiver)
122+
assert receiver.matches(selected)
124123
assert selected.was_stopped
125124
assert isinstance(selected.exception, ReceiverStoppedError)
126125
assert selected.exception.receiver is receiver
@@ -245,33 +244,33 @@ async def test_break(
245244
"""Test that break works."""
246245
selected: Selected[Any] | None = None
247246
async for selected in select(self.recv1, self.recv2, self.recv3):
248-
if selected_from(selected, self.recv1):
247+
if self.recv1.matches(selected):
249248
continue
250-
if selected_from(selected, self.recv2):
249+
if self.recv2.matches(selected):
251250
continue
252-
if selected_from(selected, self.recv3):
251+
if self.recv3.matches(selected):
253252
break
254253

255254
assert selected is not None
256255
self.assert_received_from(selected, self.recv3, at_time=2)
257256

258257
async for selected in select(self.recv1, self.recv2, self.recv3):
259-
if selected_from(selected, self.recv1):
258+
if self.recv1.matches(selected):
260259
continue
261-
if selected_from(selected, self.recv2):
260+
if self.recv2.matches(selected):
262261
break
263-
if selected_from(selected, self.recv3):
262+
if self.recv3.matches(selected):
264263
continue
265264

266265
assert selected is not None
267266
self.assert_received_from(selected, self.recv2, at_time=6)
268267

269268
async for selected in select(self.recv1, self.recv2, self.recv3):
270-
if selected_from(selected, self.recv1):
269+
if self.recv1.matches(selected):
271270
continue
272-
if selected_from(selected, self.recv2):
271+
if self.recv2.matches(selected):
273272
continue
274-
if selected_from(selected, self.recv3):
273+
if self.recv3.matches(selected):
275274
break
276275

277276
assert selected is not None
@@ -281,7 +280,7 @@ async def test_break(
281280
assert self.recv3.is_stopped
282281

283282
async for selected in select(self.recv2):
284-
if selected_from(selected, self.recv2):
283+
if self.recv2.matches(selected):
285284
continue
286285

287286
self.assert_receiver_stopped(
@@ -300,9 +299,9 @@ async def test_missed_select_from(
300299
selected: Selected[Any] | None = None
301300
with pytest.raises(UnhandledSelectedError) as excinfo:
302301
async for selected in select(self.recv1, self.recv2, self.recv3):
303-
if selected_from(selected, self.recv1):
302+
if self.recv1.matches(selected):
304303
continue
305-
if selected_from(selected, self.recv2):
304+
if self.recv2.matches(selected):
306305
continue
307306

308307
assert False, "Should not reach this point"
@@ -392,11 +391,11 @@ async def test_multiple_ready(
392391
received.clear()
393392
last_time = now
394393

395-
if selected_from(selected, self.recv1):
394+
if self.recv1.matches(selected):
396395
received.add(self.recv1.name)
397-
elif selected_from(selected, self.recv2):
396+
elif self.recv2.matches(selected):
398397
received.add(self.recv2.name)
399-
elif selected_from(selected, self.recv3):
398+
elif self.recv3.matches(selected):
400399
received.add(self.recv3.name)
401400
else:
402401
assert False, "Should not reach this point"
@@ -425,11 +424,11 @@ def test_tasks_are_cleaned_up_with_break(self) -> None:
425424
async def run() -> None:
426425
task = loop.create_task(self.run_multiple_ready())
427426
async for selected in select(self.recv1, self.recv2, self.recv3):
428-
if selected_from(selected, self.recv1):
427+
if self.recv1.matches(selected):
429428
continue
430-
if selected_from(selected, self.recv2):
429+
if self.recv2.matches(selected):
431430
continue
432-
if selected_from(selected, self.recv3):
431+
if self.recv3.matches(selected):
433432
break
434433

435434
task.cancel()

0 commit comments

Comments
 (0)