Skip to content

Commit 32599d3

Browse files
raymond13513copybara-github
authored andcommitted
Create abstraction for disruption event to handle async vs sync operations for Migration
PiperOrigin-RevId: 842572782
1 parent e8cb07e commit 32599d3

File tree

4 files changed

+42
-37
lines changed

4 files changed

+42
-37
lines changed

perfkitbenchmarker/time_triggers/base_disruption_trigger.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,30 @@ class BaseDisruptionTrigger(base_time_trigger.BaseTimeTrigger):
6464
def __init__(self, delay: int):
6565
super().__init__(delay)
6666
self.metadata = {}
67-
self.disruption_ends = None
67+
self.disruption_events: MutableSequence[DisruptionEvent] = []
6868

6969
def TriggerMethod(self, vm: virtual_machine.VirtualMachine):
70-
"""Trigger the disruption."""
70+
"""Trigger the disruption.
71+
72+
Implementation of this needs to modify the disruption_events list if the
73+
operation sync.
74+
75+
Args:
76+
vm: The VirtualMachine to trigger the disruption on.
77+
"""
7178
raise NotImplementedError()
7279

7380
def SetUp(self):
7481
"""See base class."""
7582
raise NotImplementedError()
7683

77-
def WaitForDisruption(self) -> MutableSequence[DisruptionEvent]:
78-
"""Wait for disruption to end and return the end time."""
79-
return []
84+
def WaitForDisruption(self) -> None:
85+
"""Wait for disruption to end and return the end time.
86+
87+
Only need to implement this if the operation is async. If the operation is
88+
async append the events to the disruption_events list.
89+
"""
90+
pass
8091

8192
def GetMetadataForTrigger(self, event: DisruptionEvent) -> Dict[str, Any]:
8293
"""Get the metadata for the trigger and append it to the samples."""
@@ -97,11 +108,11 @@ def AppendSamples(
97108
def generate_disruption_total_time_samples() -> (
98109
MutableSequence[sample.Sample]
99110
):
100-
events = self.WaitForDisruption()
111+
self.WaitForDisruption()
101112

102113
# Host maintenance is in s
103114
self.disruption_ends = max(
104-
[float(d.end_time) * 1000 for d in events],
115+
[float(d.end_time) * 1000 for d in self.disruption_events],
105116
default=0,
106117
)
107118

@@ -120,7 +131,7 @@ def generate_disruption_total_time_samples() -> (
120131
'seconds',
121132
sample_metadata | self.GetMetadataForTrigger(d),
122133
)
123-
for d in events
134+
for d in self.disruption_events
124135
]
125136

126137
samples += generate_disruption_total_time_samples()
@@ -251,7 +262,12 @@ def _AggregateThroughputSample(
251262

252263
def GetDisruptionEnds(self) -> float | None:
253264
"""Get the disruption ends."""
254-
return self.disruption_ends
265+
if self.disruption_events:
266+
return max(
267+
[float(d.end_time) * 1000 for d in self.disruption_events],
268+
default=None,
269+
)
270+
return None
255271

256272
def _ComputeLossPercentile(
257273
self,

perfkitbenchmarker/time_triggers/maintenance_simulation_trigger.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
"""Module containning methods for triggering maintenance simulation."""
1515

16-
from collections.abc import MutableSequence
1716
import datetime
1817
import json
1918
import logging
@@ -364,26 +363,13 @@ def SetUp(self):
364363
for helper in self.gce_simulate_maintenance_helpers.values():
365364
helper.SetupLMNotification()
366365

367-
def WaitForDisruption(
368-
self,
369-
) -> MutableSequence[base_disruption_trigger.DisruptionEvent]:
366+
def WaitForDisruption(self) -> None:
370367
"""Wait for the disruption to end and return the end time."""
371368
if self.capture_live_migration_timestamps:
372369
# Block test exit until LM ended.
373-
lm_events = []
374370
for helper in self.gce_simulate_maintenance_helpers.values():
375371
helper.WaitLMNotificationRelease()
376-
lm_events.append(helper.CollectLMNotificationsTime())
377-
return lm_events
378-
else:
379-
return []
380-
381-
def GetDisruptionEnds(self) -> float | None:
382-
"""Get the disruption ends."""
383-
if self.capture_live_migration_timestamps:
384-
# lm ends is computed from LM notification
385-
return self.disruption_ends
386-
return None
372+
self.disruption_events.append(helper.CollectLMNotificationsTime())
387373

388374
@property
389375
def trigger_name(self) -> str:

tests/time_triggers/base_disruption_trigger_test.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class BaseDisruptionTriggerTest(pkb_common_test_case.PkbCommonTestCase):
4949
def setUp(self):
5050
super().setUp()
5151
self.trigger = TestBaseDisruptionTrigger()
52+
self.maxDiff = 100000
5253
self.enter_context(
5354
mock.patch.object(self.trigger, 'WaitForDisruption', autospec=True)
5455
)
@@ -63,7 +64,7 @@ def testAppendSamples(self):
6364
}
6465
s = []
6566
vm_spec = mock.MagicMock(spec=benchmark_spec.BenchmarkSpec)
66-
self.trigger.WaitForDisruption.return_value = [
67+
self.trigger.disruption_events = [
6768
base_disruption_trigger.DisruptionEvent(
6869
total_time=10, end_time=10, start_time=4
6970
)
@@ -79,7 +80,7 @@ def testAppendLossFunctionWithDegradationPercent(self):
7980
FLAGS.maintenance_degradation_percent = 90
8081
vm_spec = mock.MagicMock(spec=benchmark_spec.BenchmarkSpec)
8182
self.trigger.trigger_time = datetime.datetime.fromtimestamp(2)
82-
self.trigger.WaitForDisruption.return_value = []
83+
self.trigger.disruption_events = []
8384
self.enter_context(
8485
mock.patch.object(
8586
self.trigger, 'GetDisruptionEnds', return_value=None, autospec=True
@@ -208,7 +209,7 @@ def testAppendLossFunctionWithDegradationPercent(self):
208209
def testAppendLossFunctionWithMissingTimeStampsWithRegression(self):
209210
vm_spec = mock.MagicMock(spec=benchmark_spec.BenchmarkSpec)
210211
self.trigger.trigger_time = datetime.datetime.fromtimestamp(2)
211-
self.trigger.WaitForDisruption.return_value = []
212+
self.trigger.disruption_events = []
212213
self.enter_context(
213214
mock.patch.object(
214215
self.trigger, 'GetDisruptionEnds', return_value=None, autospec=True
@@ -336,7 +337,7 @@ def testAppendLossFunctionWithMissingTimeStampsWithRegression(self):
336337
@mock.patch.object(time, 'time', mock.MagicMock(return_value=0))
337338
def testAppendLossFunctionWithMissingTimeStampsNoRegression(self):
338339
self.trigger.trigger_time = datetime.datetime.fromtimestamp(2)
339-
self.trigger.WaitForDisruption.return_value = []
340+
self.trigger.disruption_events = []
340341
self.enter_context(
341342
mock.patch.object(
342343
self.trigger, 'GetDisruptionEnds', return_value=None, autospec=True
@@ -465,7 +466,7 @@ def testAppendLossFunctionWithMissingTimeStampsNoRegression(self):
465466
@mock.patch.object(time, 'time', mock.MagicMock(return_value=0))
466467
def testAppendLossFunctionSamples(self):
467468
vm_spec = mock.MagicMock(spec=benchmark_spec.BenchmarkSpec)
468-
self.trigger.WaitForDisruption.return_value = []
469+
self.trigger.disruption_events = []
469470
self.enter_context(
470471
mock.patch.object(
471472
self.trigger, 'GetDisruptionEnds', return_value=None, autospec=True
@@ -611,7 +612,7 @@ def testAppendLossFunctionSamplesWithDisruptionEnds(self):
611612
)
612613
samples = [s]
613614
self.trigger.trigger_time = datetime.datetime.fromtimestamp(4)
614-
self.trigger.WaitForDisruption.return_value = [
615+
self.trigger.disruption_events = [
615616
base_disruption_trigger.DisruptionEvent(
616617
total_time=100, end_time=8, start_time=4
617618
)
@@ -811,7 +812,7 @@ def testAppendLossFunctionSamplesContainsMetadata(self):
811812
)
812813
samples = [s]
813814
self.trigger.trigger_time = datetime.datetime.fromtimestamp(4)
814-
self.trigger.WaitForDisruption.return_value = [
815+
self.trigger.disruption_events = [
815816
base_disruption_trigger.DisruptionEvent(
816817
total_time=100, end_time=8, start_time=4
817818
)
@@ -1016,7 +1017,7 @@ def testAppendLossFunctionSamplesHandleTimeDrift(self):
10161017
)
10171018
samples = [s]
10181019
self.trigger.trigger_time = datetime.datetime.fromtimestamp(4)
1019-
self.trigger.WaitForDisruption.return_value = [
1020+
self.trigger.disruption_events = [
10201021
base_disruption_trigger.DisruptionEvent(
10211022
total_time=100, end_time=11, start_time=4
10221023
)
@@ -1175,7 +1176,7 @@ def testMaintenanceEventTriggerAppendSamplesWithMaintenanceDegradationWindow(
11751176
)
11761177
samples = [s]
11771178
self.trigger.trigger_time = datetime.datetime.fromtimestamp(4)
1178-
self.trigger.WaitForDisruption.return_value = [
1179+
self.trigger.disruption_events = [
11791180
base_disruption_trigger.DisruptionEvent(
11801181
total_time=100, end_time=11, start_time=4
11811182
)
@@ -1334,7 +1335,7 @@ def testMaintenanceEventTriggerAppendSamplesWithRegressionOutsideMaintenanceWind
13341335
)
13351336
samples = [s]
13361337
self.trigger.trigger_time = datetime.datetime.fromtimestamp(4)
1337-
self.trigger.WaitForDisruption.return_value = [
1338+
self.trigger.disruption_events = [
13381339
base_disruption_trigger.DisruptionEvent(
13391340
total_time=100, end_time=8, start_time=4
13401341
)

tests/time_triggers/maintenance_simulation_trigger_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ def testWaitForDisruption(self):
7575
s = []
7676
trigger = maintenance_simulation_trigger.MaintenanceEventTrigger()
7777
self.enter_context(
78-
mock.patch.object(trigger, 'WaitForDisruption', return_value=[event])
78+
mock.patch.object(trigger, 'WaitForDisruption', return_value=None)
7979
)
80+
trigger.disruption_events = [event]
8081
trigger.capture_live_migration_timestamps = True
8182
trigger.vms = [vm]
8283
trigger.AppendSamples(None, vm_spec, s)
@@ -112,8 +113,9 @@ def testWaitForDisruptionReturnsCorrectValue(self):
112113
].CollectLMNotificationsTime.return_value = event
113114
trigger.capture_live_migration_timestamps = True
114115
trigger.vms = [vm]
116+
trigger.WaitForDisruption()
115117
self.assertEqual(
116-
trigger.WaitForDisruption(),
118+
trigger.disruption_events,
117119
[event],
118120
)
119121

0 commit comments

Comments
 (0)