Skip to content

Commit 0d8b2a0

Browse files
yuyantingzero authored and copybara-github committed
Initial commit of cluster resource.
PiperOrigin-RevId: 720709547
1 parent 2425f59 commit 0d8b2a0

File tree

3 files changed

+337
-0
lines changed

3 files changed

+337
-0
lines changed

perfkitbenchmarker/benchmark_spec.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from perfkitbenchmarker import benchmark_status
3030
from perfkitbenchmarker import capacity_reservation
3131
from perfkitbenchmarker import cloud_tpu
32+
from perfkitbenchmarker import cluster
3233
from perfkitbenchmarker import container_service
3334
from perfkitbenchmarker import context
3435
from perfkitbenchmarker import data_discovery_service
@@ -180,6 +181,7 @@ def __init__(
180181
self.always_call_cleanup = pkb_flags.ALWAYS_CALL_CLEANUP.value
181182
self.dpb_service: dpb_service.BaseDpbService = None
182183
self.container_cluster: container_service.BaseContainerCluster = None
184+
self.cluster: cluster.BaseCluster = None
183185
self.key = None
184186
self.relational_db = None
185187
self.non_relational_db = None
@@ -316,6 +318,7 @@ def ConstructResources(self):
316318
self.ConstructBaseJob()
317319
self.ConstructMemoryStore()
318320
self.ConstructPinecone()
321+
self.ConstructCluster()
319322

320323
def ConstructContainerCluster(self):
321324
"""Create the container cluster."""
@@ -332,6 +335,16 @@ def ConstructContainerCluster(self):
332335
)
333336
self.resources.append(self.container_cluster)
334337

338+
def ConstructCluster(self):
  """Builds the cluster resource when the benchmark config declares one.

  Does nothing if the config has no cluster section. Otherwise loads the
  provider for the configured cloud, instantiates the matching cluster class
  from the cluster config, stores it on self.cluster and registers it in
  self.resources.
  """
  cluster_config = self.config.cluster
  if cluster_config is None:
    return
  target_cloud = cluster_config.cloud
  # The provider is loaded before the cluster class for that cloud is looked up.
  providers.LoadProvider(target_cloud)
  self.cluster = cluster.GetClusterClass(target_cloud)(cluster_config)
  self.resources.append(self.cluster)
347+
335348
def ConstructContainerRegistry(self):
336349
"""Create the container registry."""
337350
if self.config.container_registry is None:
@@ -880,6 +893,9 @@ def Provision(self):
880893
if self.container_cluster:
881894
self.container_cluster.Create()
882895

896+
if self.cluster:
897+
self.cluster.Create()
898+
883899
# do after network setup but before VM created
884900
if self.nfs_service and self.nfs_service.CLOUD != nfs_service.UNMANAGED:
885901
self.nfs_service.Create()
@@ -1043,6 +1059,9 @@ def Delete(self):
10431059
self.container_cluster.DeleteContainers()
10441060
self.container_cluster.Delete()
10451061

1062+
if self.cluster:
1063+
self.cluster.Delete()
1064+
10461065
for net in self.networks.values():
10471066
try:
10481067
net.Delete()

perfkitbenchmarker/cluster.py

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
# Copyright 2025 PerfKitBenchmarker Authors. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Class to represent a Cluster object."""
16+
import typing
17+
from typing import Callable, List, Tuple
18+
19+
from absl import flags
20+
from perfkitbenchmarker import disk
21+
from perfkitbenchmarker import errors
22+
from perfkitbenchmarker import linux_virtual_machine
23+
from perfkitbenchmarker import resource
24+
from perfkitbenchmarker import static_virtual_machine
25+
from perfkitbenchmarker import virtual_machine
26+
from perfkitbenchmarker.configs import option_decoders
27+
from perfkitbenchmarker.configs import spec
28+
from perfkitbenchmarker.configs import vm_group_decoders
29+
30+
31+
FLAGS = flags.FLAGS
# Optional override for the template used to create the cluster.
TEMPLATE_FILE = flags.DEFINE_string(
    name='cluster_template_file',
    default=None,
    help=(
        'The template file to be used to create the cluster. None by default, '
        'each provider has a default template file.'
    ),
)
38+
39+
40+
class BaseClusterSpec(spec.BaseSpec):
  """Storing various data about HPC/ML cluster.

  The configurable options (see _GetOptionDecoderConstructions) are:
    workers: VM group spec describing the worker nodes.
    headnode: VM group spec describing the head node.
    cloud: Name of the cloud hosting the cluster. Optional in the config.
    template: Template file used to create the cluster. Optional; defaults to
      None.
  """

  SPEC_TYPE = 'BaseClusterSpec'
  SPEC_ATTRS = ['CLOUD']
  CLOUD = None

  @classmethod
  def _ApplyFlags(cls, config_values, flag_values):
    """Overrides config values with flag values.

    Can be overridden by derived classes to add support for specific flags.
    The overrides are written into config_values in place; nothing is
    returned.

    Args:
      config_values: dict mapping config option names to provided values. Is
        modified by this function.
      flag_values: flags.FlagValues. Runtime flags that may override the
        provided config values.
    """
    super()._ApplyFlags(config_values, flag_values)
    # 'cloud' is optional in the config, but it is read unconditionally below
    # to index the per-cloud vm_spec. Fall back to the flag's value when the
    # config omits it so that lookup cannot raise KeyError.
    if flag_values['cloud'].present or 'cloud' not in config_values:
      config_values['cloud'] = flag_values.cloud
    if flag_values['cluster_template_file'].present:
      config_values['template'] = flag_values.cluster_template_file
    cloud = config_values['cloud']
    # --num_vms only applies to workers.
    if flag_values['num_vms'].present:
      config_values['workers']['vm_count'] = flag_values['num_vms'].value
    # The remaining flags are applied to both workers and headnode.
    node_groups = ('workers', 'headnode')
    if flag_values['zone'].present:
      # Only the first entry of the zone flag is used.
      zone = flag_values['zone'].value[0]
      for group in node_groups:
        config_values[group]['vm_spec'][cloud]['zone'] = zone
    for flag_name in ('os_type', 'cloud'):
      if flag_values[flag_name].present:
        for group in node_groups:
          config_values[group][flag_name] = flag_values[flag_name].value

  @classmethod
  def _GetOptionDecoderConstructions(cls):
    """Gets decoder classes and constructor args for each configurable option.

    Can be overridden by derived classes to add options or impose additional
    requirements on existing options.

    Returns:
      dict. Maps option name string to a (ConfigOptionDecoder class, dict) pair.
          The pair specifies a decoder class and its __init__() keyword
          arguments to construct in order to decode the named option.
    """
    result = super()._GetOptionDecoderConstructions()
    result.update({
        'workers': (vm_group_decoders.VmGroupSpecDecoder, {}),
        'headnode': (vm_group_decoders.VmGroupSpecDecoder, {}),
        'cloud': (option_decoders.StringDecoder, {'default': None}),
        'template': (option_decoders.StringDecoder, {'default': None}),
    })
    return result
111+
112+
113+
def GetClusterSpecClass(cloud: str):
  """Returns the cluster spec class registered for the given cloud."""
  spec_class = spec.GetSpecClass(BaseClusterSpec, CLOUD=cloud)
  return spec_class
116+
117+
118+
def GetClusterClass(cloud: str):
  """Returns the cluster resource class registered for the given cloud."""
  cluster_class = resource.GetResourceClass(BaseCluster, CLOUD=cloud)
  return cluster_class
121+
122+
123+
class BaseCluster(resource.BaseResource):
  """Base class for cluster resources.

  Holds the cluster-level state and helpers shared by cluster implementations.

  Attributes:
    zone: Zone taken from the worker VM spec.
    machine_type: Machine type taken from the worker VM spec.
    worker_machine_type: Same value as machine_type.
    headnode_machine_type: Machine type taken from the headnode VM spec.
    headnode_spec: VM spec of the headnode group.
    workers_spec: VM spec of the worker group.
    image: Image taken from the worker VM spec.
    workers_static_disk_spec: Disk spec of the worker group.
    workers_static_disk: StaticDisk built from workers_static_disk_spec, or
      None when that spec is falsy.
    os_type: OS type of the worker group.
    num_workers: vm_count of the worker group.
    vms: VM objects belonging to the cluster. Starts empty.
    headnode_vm: The headnode VM. None until it is assigned.
    worker_vms: The worker VMs. Starts empty.
    name: Name of the cluster, set to the run URI.
    nfs_path: Starts as None.
  """

  RESOURCE_TYPE = 'BaseCluster'
  REQUIRED_ATTRS = ['CLOUD']

  def __init__(self, cluster_spec: BaseClusterSpec):
    """Initialize BaseCluster class.

    Args:
      cluster_spec: cluster.BaseClusterSpec object.
    """
    super().__init__()
    worker_group = cluster_spec.workers
    headnode_group = cluster_spec.headnode

    # Worker-derived settings.
    self.workers_spec: virtual_machine.BaseVmSpec = worker_group.vm_spec
    self.zone: str = self.workers_spec.zone
    self.machine_type: str = self.workers_spec.machine_type
    self.worker_machine_type: str = self.machine_type
    self.image: str = self.workers_spec.image
    self.os_type: str = worker_group.os_type
    self.num_workers: int = worker_group.vm_count

    # Headnode-derived settings.
    self.headnode_spec: virtual_machine.BaseVmSpec = headnode_group.vm_spec
    self.headnode_machine_type: str = self.headnode_spec.machine_type

    # A static disk object is only built when the workers declare a disk spec.
    self.workers_static_disk_spec: disk.BaseDiskSpec = worker_group.disk_spec
    self.workers_static_disk: static_virtual_machine.StaticDisk | None = None
    if self.workers_static_disk_spec:
      self.workers_static_disk = static_virtual_machine.StaticDisk(
          self.workers_static_disk_spec
      )

    # VM handles start empty; nothing in this class populates them.
    self.vms: List[linux_virtual_machine.BaseLinuxVirtualMachine] = []
    self.headnode_vm: linux_virtual_machine.BaseLinuxVirtualMachine | None = (
        None
    )
    self.worker_vms: List[linux_virtual_machine.BaseLinuxVirtualMachine] = []
    self.name: str = FLAGS.run_uri
    self.nfs_path: str | None = None

  def GetResourceMetadata(self):
    """Returns a dict of cluster settings to report as metadata."""
    metadata = {}
    metadata['zone'] = self.zone
    metadata['machine_type'] = self.machine_type
    metadata['worker_machine_type'] = self.worker_machine_type
    metadata['headnode_machine_type'] = self.headnode_machine_type
    metadata['image'] = self.image
    metadata['os_type'] = self.os_type
    metadata['num_workers'] = self.num_workers
    return metadata

  def __repr__(self):
    """Returns a short representation containing the cluster name."""
    return f'<BaseCluster [name={self.name}]>'

  # TODO(yuyanting) Move common logic here after having concrete implementation.
  def _RenderClusterConfig(self):
    """Render the config file that will be used to create the cluster."""

  def RemoteCommand(
      self,
      command: str,
      ignore_failure: bool = False,
      timeout: float | None = None,
      **kwargs,
  ) -> Tuple[str, str]:
    """Runs a command via srun, issued from the headnode VM.

    The command is wrapped as `srun -N <num_workers> <command>` and handed to
    the headnode VM's RemoteCommand.

    Derived classes may add additional kwargs if necessary, but they should not
    be used outside of the class itself since they are non standard.

    Args:
      command: A valid bash command.
      ignore_failure: Ignore any failure if set to true.
      timeout: The time to wait in seconds for the command before exiting. None
        means no timeout.
      **kwargs: Additional command arguments, forwarded to the headnode VM.

    Returns:
      A tuple of stdout and stderr from running the command.

    Raises:
      RemoteCommandError: If there was a problem issuing the command.
    """
    srun_command = f'srun -N {self.num_workers} {command}'
    return self.headnode_vm.RemoteCommand(
        srun_command,
        ignore_failure=ignore_failure,
        timeout=timeout,
        **kwargs,
    )

  def RobustRemoteCommand(
      self,
      command: str,
      timeout: float | None = None,
      ignore_failure: bool = False,
  ) -> Tuple[str, str]:
    """Runs a command on the headnode VM via its RobustRemoteCommand.

    Unlike RemoteCommand, the command is passed through unchanged (it is not
    wrapped in srun). Derived classes may override this.

    Args:
      command: The command to run.
      timeout: The timeout for the command in seconds.
      ignore_failure: Ignore any failure if set to true.

    Returns:
      A tuple of stdout, stderr from running the command.

    Raises:
      RemoteCommandError: If there was a problem establishing the connection, or
          the command fails.
    """
    headnode = self.headnode_vm
    return headnode.RobustRemoteCommand(
        command, ignore_failure=ignore_failure, timeout=timeout
    )

  def TryRemoteCommand(self, command: str, **kwargs):
    """Runs a remote command and returns True iff it succeeded."""
    try:
      self.RemoteCommand(command, **kwargs)
    except errors.VirtualMachine.RemoteCommandError:
      return False
    return True

  def BackfillVm(
      self,
      vm_spec: virtual_machine.BaseVmSpec,
      fn: Callable[[virtual_machine.BaseVirtualMachine], None],
  ):
    """Create and backfill a VM object created using cluster resource.

    The VM class is looked up from the spec's cloud and this cluster's os_type.
    After fn has run on the new object, its disk list is reset, its post-create
    hook is invoked and it is marked as created.

    Args:
      vm_spec: VM spec to be used to find corresponding VM class.
      fn: The function to be called on the newly created VM.

    Returns:
      The newly created VM object.
    """
    backfilled_vm = virtual_machine.GetVmClass(vm_spec.CLOUD, self.os_type)(
        vm_spec
    )
    fn(backfilled_vm)
    backfilled_vm.disks = []
    backfilled_vm._PostCreate()  # pylint: disable=protected-access
    backfilled_vm.created = True
    return backfilled_vm

  def AuthenticateVM(self):
    """Authenticate a remote machine to access all vms."""
    for cluster_vm in self.vms:
      cluster_vm.AuthenticateVm()
285+
286+
287+
# Type variable bound to BaseCluster: it accepts BaseCluster or any subclass.
Cluster = typing.TypeVar('Cluster', bound=BaseCluster)

0 commit comments

Comments
 (0)