Skip to content

Commit 70550f9

Browse files
authored
Phase Sequence: modifier functions that copy must use attr.evolve (#961)
PiperOrigin-RevId: 337203373
1 parent 94bebbc commit 70550f9

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

openhtf/core/phase_collections.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,28 +155,33 @@ def _asdict(self) -> Dict[Text, Any]:
155155

156156
def with_args(self: SequenceClassT, **kwargs: Any) -> SequenceClassT:
157157
"""Send these keyword-arguments when phases are called."""
158-
return type(self)(
158+
return attr.evolve(
159+
self,
159160
nodes=tuple(n.with_args(**kwargs) for n in self.nodes),
160161
name=util.format_string(self.name, kwargs))
161162

162163
def with_plugs(self: SequenceClassT,
163164
**subplugs: Type[base_plugs.BasePlug]) -> SequenceClassT:
164165
"""Substitute plugs for placeholders for this phase, error on unknowns."""
165-
return type(self)(
166+
return attr.evolve(
167+
self,
166168
nodes=tuple(n.with_plugs(**subplugs) for n in self.nodes),
167169
name=util.format_string(self.name, subplugs))
168170

169171
def load_code_info(self: SequenceClassT) -> SequenceClassT:
170172
"""Load coded info for all contained phases."""
171-
return type(self)(
172-
nodes=tuple(n.load_code_info() for n in self.nodes), name=self.name)
173+
return attr.evolve(
174+
self,
175+
nodes=tuple(n.load_code_info() for n in self.nodes),
176+
name=self.name)
173177

174178
def apply_to_all_phases(
175179
self: SequenceClassT, func: Callable[[phase_descriptor.PhaseDescriptor],
176180
phase_descriptor.PhaseDescriptor]
177181
) -> SequenceClassT:
178182
"""Apply func to all contained phases."""
179-
return type(self)(
183+
return attr.evolve(
184+
self,
180185
nodes=tuple(n.apply_to_all_phases(func) for n in self.nodes),
181186
name=self.name)
182187

test/core/phase_branches_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,58 @@ def test_as_dict(self):
7272
}
7373
self.assertEqual(expected, branch._asdict())
7474

75+
def test_with_args(self):
76+
branch = phase_branches.BranchSequence(
77+
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
78+
nodes=(run_phase,),
79+
name='name_{arg}')
80+
expected = phase_branches.BranchSequence(
81+
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
82+
nodes=(run_phase.with_args(arg=1),),
83+
name='name_1')
84+
85+
self.assertEqual(expected, branch.with_args(arg=1))
86+
87+
def test_with_plugs(self):
88+
89+
class MyPlug(htf.BasePlug):
90+
pass
91+
92+
branch = phase_branches.BranchSequence(
93+
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
94+
nodes=(run_phase,),
95+
name='name_{my_plug.__name__}')
96+
expected = phase_branches.BranchSequence(
97+
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
98+
nodes=(run_phase.with_plugs(my_plug=MyPlug),),
99+
name='name_MyPlug')
100+
101+
self.assertEqual(expected, branch.with_plugs(my_plug=MyPlug))
102+
103+
def test_load_code_info(self):
104+
branch = phase_branches.BranchSequence(
105+
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
106+
nodes=(run_phase,))
107+
expected = phase_branches.BranchSequence(
108+
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
109+
nodes=(run_phase.load_code_info(),))
110+
111+
self.assertEqual(expected, branch.load_code_info())
112+
113+
def test_apply_to_all_phases(self):
114+
115+
def do_rename(phase):
116+
return _rename(phase, 'blah_blah')
117+
118+
branch = phase_branches.BranchSequence(
119+
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
120+
nodes=(run_phase,))
121+
expected = phase_branches.BranchSequence(
122+
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
123+
nodes=(do_rename(run_phase),))
124+
125+
self.assertEqual(expected, branch.apply_to_all_phases(do_rename))
126+
75127

76128
class BranchSequenceIntegrationTest(htf_test.TestCase):
77129

0 commit comments

Comments
 (0)