Skip to content

Commit 3f25449

Browse files
glados-vermacopybara-github
authored andcommitted
Relax type requirement from Sequence to Iterable for get_flattened phases
Add test case showing where this is useful PiperOrigin-RevId: 761319671
1 parent 635c0ab commit 3f25449

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

openhtf/util/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1016,7 +1016,7 @@ def get_profile_filepath(cls) -> Optional[pathlib.Path]:
10161016

10171017

10181018
def get_flattened_phases(
1019-
node_collections: Sequence[
1019+
node_collections: Iterable[
10201020
Union[phase_nodes.PhaseNode, phase_collections.PhaseCollectionNode]
10211021
],
10221022
) -> Sequence[phase_nodes.PhaseNode]:

test/util/test_test.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -358,16 +358,17 @@ def test_profile_test(self):
358358
'.', os.path.sep), output)
359359

360360

361-
class GetFlattenedPhasesTest(unittest.TestCase):
361+
def no_op_phase():
362+
"""No-op phase."""
363+
362364

363-
def test_unflattens_nested_mixed_nodes(self):
365+
def make_phase(name: str):
366+
return openhtf.PhaseOptions(name=name)(no_op_phase)
364367

365-
def no_op_phase():
366-
"""No-op phase."""
367368

368-
def make_phase(name: str):
369-
return openhtf.PhaseOptions(name=name)(no_op_phase)
369+
class GetFlattenedPhasesTest(unittest.TestCase):
370370

371+
def test_unflattens_nested_mixed_nodes_list(self):
371372
nested_nodes = [
372373
make_phase('TopLevelPhase'),
373374
[make_phase('NestedPhase1'), make_phase('NestedPhase2')],
@@ -410,3 +411,22 @@ def make_phase(name: str):
410411
'NestedPhase3',
411412
],
412413
)
414+
415+
def test_unflattens_nested_mixed_nodes_iterable(self):
416+
nodes_iterable = openhtf.PhaseGroup(
417+
setup=make_phase('SetupPhase1a'),
418+
main=[make_phase('MainPhase1a'), make_phase('MainPhase1b')],
419+
teardown=make_phase('TeardownPhase1a'),
420+
).all_phases()
421+
node_names = []
422+
for node in test.get_flattened_phases(nodes_iterable):
423+
node_names.append(node.name)
424+
self.assertEqual(
425+
node_names,
426+
[
427+
'SetupPhase1a',
428+
'MainPhase1a',
429+
'MainPhase1b',
430+
'TeardownPhase1a',
431+
],
432+
)

0 commit comments

Comments
 (0)