Skip to content

Commit 75271f2

Browse files
glados-vermacopybara-github
authored andcommitted
Add assertion that verifies and yields the desired phase record
We often use this pattern when writing unit tests against a test record and need to find a phase given its name. PiperOrigin-RevId: 743225123
1 parent 7ab26ea commit 75271f2

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

openhtf/util/test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def test_multiple(self, mock_my_plug):
127127
"""
128128

129129
from collections.abc import Callable as CollectionsCallable, Iterator
130+
import contextlib
130131
import functools
131132
import inspect
132133
import logging
@@ -786,6 +787,21 @@ def assertTestOutcomeCode(self, test_rec, code):
786787
any(details.code == code for details in test_rec.outcome_details),
787788
'No OutcomeDetails had code %s' % code)
788789

790+
@contextlib.contextmanager
791+
def assertTestHasPhaseRecord(self, test_rec, phase_name):
792+
"""Yields a PhaseRecord with the given name, else asserts."""
793+
all_phase_names = []
794+
expected_phase_rec = None
795+
for phase_rec in test_rec.phases:
796+
all_phase_names.append(phase_rec.name)
797+
if phase_rec.name == phase_name:
798+
expected_phase_rec = phase_rec
799+
self.assertIsNotNone(
800+
expected_phase_rec,
801+
msg=f'Phase "{phase_name}" not found in test phases: {all_phase_names}',
802+
)
803+
yield expected_phase_rec
804+
789805
##### PhaseRecord Assertions #####
790806

791807
def assertPhaseContinue(self, phase_record):

test/util/test_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,19 @@ def bad_test(cls_self): # pylint: disable=unused-argument
278278
with self.assertRaises(test.InvalidTestError):
279279
test.yields_phases(bad_test)(self)
280280

281+
def test_assert_test_has_phase_record(self):
282+
self.auto_mock_plugs(MyPlug)
283+
test_record = self.execute_phase_or_test(
284+
openhtf.Test(test_phase_with_shameless_plug, test_phase)
285+
)
286+
with self.subTest('phase_found'):
287+
with self.assertTestHasPhaseRecord(test_record, test_phase.name) as phase:
288+
self.assertEqual(phase.name, test_phase.name)
289+
with self.subTest('phase_not_found'):
290+
with self.assertRaises(AssertionError):
291+
with self.assertTestHasPhaseRecord(test_record, 'nonexistent_phase'):
292+
pass
293+
281294

282295
class PhaseProfilingTest(test.TestCase):
283296
"""Test profiling an OpenHTF phase in unit testing.

0 commit comments

Comments
 (0)