Skip to content

Commit 0a35951

Browse files
authored
Merge pull request #2396 from dstackai/issue_2372_common_run_job
Provide default run_job implementation for VM backends
2 parents ef720c2 + 564918c commit 0a35951

File tree

13 files changed

+67
-254
lines changed

13 files changed

+67
-254
lines changed

src/dstack/_internal/core/backends/aws/compute.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
generate_unique_instance_name,
2323
generate_unique_volume_name,
2424
get_gateway_user_data,
25-
get_job_instance_name,
2625
get_user_data,
2726
merge_tags,
2827
)
@@ -39,11 +38,10 @@
3938
InstanceConfiguration,
4039
InstanceOffer,
4140
InstanceOfferWithAvailability,
42-
SSHKey,
4341
)
4442
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
4543
from dstack._internal.core.models.resources import Memory, Range
46-
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
44+
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
4745
from dstack._internal.core.models.volumes import (
4846
Volume,
4947
VolumeAttachmentData,
@@ -69,14 +67,14 @@ class AWSVolumeBackendData(CoreModel):
6967

7068

7169
class AWSCompute(
72-
Compute,
7370
ComputeWithCreateInstanceSupport,
7471
ComputeWithMultinodeSupport,
7572
ComputeWithReservationSupport,
7673
ComputeWithPlacementGroupSupport,
7774
ComputeWithGatewaySupport,
7875
ComputeWithPrivateGatewaySupport,
7976
ComputeWithVolumeSupport,
77+
Compute,
8078
):
8179
def __init__(self, config: AWSConfig):
8280
super().__init__()
@@ -285,44 +283,6 @@ def create_instance(
285283
continue
286284
raise NoCapacityError()
287285

288-
def run_job(
289-
self,
290-
run: Run,
291-
job: Job,
292-
instance_offer: InstanceOfferWithAvailability,
293-
project_ssh_public_key: str,
294-
project_ssh_private_key: str,
295-
volumes: List[Volume],
296-
) -> JobProvisioningData:
297-
# TODO: run_job is the same for vm-based backends, refactor
298-
instance_config = InstanceConfiguration(
299-
project_name=run.project_name,
300-
instance_name=get_job_instance_name(run, job), # TODO: generate name
301-
ssh_keys=[
302-
SSHKey(public=project_ssh_public_key.strip()),
303-
],
304-
user=run.user,
305-
volumes=volumes,
306-
reservation=run.run_spec.configuration.reservation,
307-
)
308-
instance_offer = instance_offer.copy()
309-
if len(volumes) > 0:
310-
volume = volumes[0]
311-
if (
312-
volume.provisioning_data is not None
313-
and volume.provisioning_data.availability_zone is not None
314-
):
315-
if instance_offer.availability_zones is None:
316-
instance_offer.availability_zones = [
317-
volume.provisioning_data.availability_zone
318-
]
319-
instance_offer.availability_zones = [
320-
z
321-
for z in instance_offer.availability_zones
322-
if z == volume.provisioning_data.availability_zone
323-
]
324-
return self.create_instance(instance_offer, instance_config)
325-
326286
def create_placement_group(
327287
self,
328288
placement_group: PlacementGroup,

src/dstack/_internal/core/backends/azure/compute.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
generate_unique_gateway_instance_name,
4646
generate_unique_instance_name,
4747
get_gateway_user_data,
48-
get_job_instance_name,
4948
get_user_data,
5049
merge_tags,
5150
)
@@ -62,11 +61,9 @@
6261
InstanceOffer,
6362
InstanceOfferWithAvailability,
6463
InstanceType,
65-
SSHKey,
6664
)
6765
from dstack._internal.core.models.resources import Memory, Range
68-
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
69-
from dstack._internal.core.models.volumes import Volume
66+
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
7067
from dstack._internal.utils.logging import get_logger
7168

7269
logger = get_logger(__name__)
@@ -75,10 +72,10 @@
7572

7673

7774
class AzureCompute(
78-
Compute,
7975
ComputeWithCreateInstanceSupport,
8076
ComputeWithMultinodeSupport,
8177
ComputeWithGatewaySupport,
78+
Compute,
8279
):
8380
def __init__(self, config: AzureConfig, credential: TokenCredential):
8481
super().__init__()
@@ -198,25 +195,6 @@ def create_instance(
198195
backend_data=None,
199196
)
200197

201-
def run_job(
202-
self,
203-
run: Run,
204-
job: Job,
205-
instance_offer: InstanceOfferWithAvailability,
206-
project_ssh_public_key: str,
207-
project_ssh_private_key: str,
208-
volumes: List[Volume],
209-
) -> JobProvisioningData:
210-
instance_config = InstanceConfiguration(
211-
project_name=run.project_name,
212-
instance_name=get_job_instance_name(run, job), # TODO: generate name
213-
ssh_keys=[
214-
SSHKey(public=project_ssh_public_key.strip()),
215-
],
216-
user=run.user,
217-
)
218-
return self.create_instance(instance_offer, instance_config)
219-
220198
def terminate_instance(
221199
self, instance_id: str, region: str, backend_data: Optional[str] = None
222200
):

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from dstack._internal.core.models.instances import (
2626
InstanceConfiguration,
2727
InstanceOfferWithAvailability,
28+
SSHKey,
2829
)
2930
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
3031
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
@@ -144,6 +145,51 @@ def create_instance(
144145
"""
145146
pass
146147

148+
def run_job(
149+
self,
150+
run: Run,
151+
job: Job,
152+
instance_offer: InstanceOfferWithAvailability,
153+
project_ssh_public_key: str,
154+
project_ssh_private_key: str,
155+
volumes: List[Volume],
156+
) -> JobProvisioningData:
157+
"""
158+
The default `run_job()` implementation for all backends that support `create_instance()`.
159+
Override only if custom `run_job()` behavior is required.
160+
"""
161+
instance_config = InstanceConfiguration(
162+
project_name=run.project_name,
163+
instance_name=get_job_instance_name(run, job),
164+
user=run.user,
165+
ssh_keys=[SSHKey(public=project_ssh_public_key.strip())],
166+
volumes=volumes,
167+
reservation=run.run_spec.configuration.reservation,
168+
)
169+
instance_offer = instance_offer.copy()
170+
self._restrict_instance_offer_az_to_volumes_az(instance_offer, volumes)
171+
return self.create_instance(instance_offer, instance_config)
172+
173+
def _restrict_instance_offer_az_to_volumes_az(
174+
self,
175+
instance_offer: InstanceOfferWithAvailability,
176+
volumes: List[Volume],
177+
):
178+
if len(volumes) == 0:
179+
return
180+
volume = volumes[0]
181+
if (
182+
volume.provisioning_data is not None
183+
and volume.provisioning_data.availability_zone is not None
184+
):
185+
if instance_offer.availability_zones is None:
186+
instance_offer.availability_zones = [volume.provisioning_data.availability_zone]
187+
instance_offer.availability_zones = [
188+
z
189+
for z in instance_offer.availability_zones
190+
if z == volume.provisioning_data.availability_zone
191+
]
192+
147193

148194
class ComputeWithMultinodeSupport:
149195
"""

src/dstack/_internal/core/backends/cudo/compute.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from dstack._internal.core.backends.base.compute import (
77
ComputeWithCreateInstanceSupport,
88
generate_unique_instance_name,
9-
get_job_instance_name,
109
get_shim_commands,
1110
)
1211
from dstack._internal.core.backends.base.offers import get_catalog_offers
@@ -18,10 +17,8 @@
1817
InstanceAvailability,
1918
InstanceConfiguration,
2019
InstanceOfferWithAvailability,
21-
SSHKey,
2220
)
23-
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
24-
from dstack._internal.core.models.volumes import Volume
21+
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
2522
from dstack._internal.utils.logging import get_logger
2623

2724
logger = get_logger(__name__)
@@ -31,8 +28,8 @@
3128

3229

3330
class CudoCompute(
34-
Compute,
3531
ComputeWithCreateInstanceSupport,
32+
Compute,
3633
):
3734
def __init__(self, config: CudoConfig):
3835
super().__init__()
@@ -55,25 +52,6 @@ def get_offers(
5552
]
5653
return offers
5754

58-
def run_job(
59-
self,
60-
run: Run,
61-
job: Job,
62-
instance_offer: InstanceOfferWithAvailability,
63-
project_ssh_public_key: str,
64-
project_ssh_private_key: str,
65-
volumes: List[Volume],
66-
) -> JobProvisioningData:
67-
instance_config = InstanceConfiguration(
68-
project_name=run.project_name,
69-
instance_name=get_job_instance_name(run, job),
70-
ssh_keys=[
71-
SSHKey(public=project_ssh_public_key.strip()),
72-
],
73-
user=run.user,
74-
)
75-
return self.create_instance(instance_offer, instance_config)
76-
7755
def create_instance(
7856
self,
7957
instance_offer: InstanceOfferWithAvailability,

src/dstack/_internal/core/backends/datacrunch/compute.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@
1515
InstanceConfiguration,
1616
InstanceOffer,
1717
InstanceOfferWithAvailability,
18-
SSHKey,
1918
)
2019
from dstack._internal.core.models.resources import Memory, Range
21-
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
22-
from dstack._internal.core.models.volumes import Volume
20+
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
2321
from dstack._internal.utils.logging import get_logger
2422

2523
logger = get_logger("datacrunch.compute")
@@ -35,8 +33,8 @@
3533

3634

3735
class DataCrunchCompute(
38-
Compute,
3936
ComputeWithCreateInstanceSupport,
37+
Compute,
4038
):
4139
def __init__(self, config: DataCrunchConfig):
4240
super().__init__()
@@ -152,25 +150,6 @@ def create_instance(
152150
backend_data=None,
153151
)
154152

155-
def run_job(
156-
self,
157-
run: Run,
158-
job: Job,
159-
instance_offer: InstanceOfferWithAvailability,
160-
project_ssh_public_key: str,
161-
project_ssh_private_key: str,
162-
volumes: List[Volume],
163-
) -> JobProvisioningData:
164-
instance_config = InstanceConfiguration(
165-
project_name=run.project_name,
166-
instance_name=job.job_spec.job_name, # TODO: generate name
167-
ssh_keys=[
168-
SSHKey(public=project_ssh_public_key.strip()),
169-
],
170-
user=run.user,
171-
)
172-
return self.create_instance(instance_offer, instance_config)
173-
174153
def terminate_instance(
175154
self, instance_id: str, region: str, backend_data: Optional[str] = None
176155
) -> None:

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
generate_unique_instance_name,
2121
generate_unique_volume_name,
2222
get_gateway_user_data,
23-
get_job_instance_name,
2423
get_shim_commands,
2524
get_user_data,
2625
merge_tags,
@@ -46,10 +45,9 @@
4645
InstanceOfferWithAvailability,
4746
InstanceType,
4847
Resources,
49-
SSHKey,
5048
)
5149
from dstack._internal.core.models.resources import Memory, Range
52-
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
50+
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
5351
from dstack._internal.core.models.volumes import (
5452
Volume,
5553
VolumeAttachmentData,
@@ -74,11 +72,11 @@ class GCPVolumeDiskBackendData(CoreModel):
7472

7573

7674
class GCPCompute(
77-
Compute,
7875
ComputeWithCreateInstanceSupport,
7976
ComputeWithMultinodeSupport,
8077
ComputeWithGatewaySupport,
8178
ComputeWithVolumeSupport,
79+
Compute,
8280
):
8381
def __init__(self, config: GCPConfig):
8482
super().__init__()
@@ -373,44 +371,6 @@ def update_provisioning_data(
373371
f"Failed to get instance IP address. Instance status: {instance.status}"
374372
)
375373

376-
def run_job(
377-
self,
378-
run: Run,
379-
job: Job,
380-
instance_offer: InstanceOfferWithAvailability,
381-
project_ssh_public_key: str,
382-
project_ssh_private_key: str,
383-
volumes: List[Volume],
384-
) -> JobProvisioningData:
385-
# TODO: run_job is the same for vm-based backends, refactor
386-
instance_config = InstanceConfiguration(
387-
project_name=run.project_name,
388-
instance_name=get_job_instance_name(run, job), # TODO: generate name
389-
ssh_keys=[
390-
SSHKey(public=project_ssh_public_key.strip()),
391-
],
392-
user=run.user,
393-
volumes=volumes,
394-
reservation=run.run_spec.configuration.reservation,
395-
)
396-
instance_offer = instance_offer.copy()
397-
if len(volumes) > 0:
398-
volume = volumes[0]
399-
if (
400-
volume.provisioning_data is not None
401-
and volume.provisioning_data.availability_zone is not None
402-
):
403-
if instance_offer.availability_zones is None:
404-
instance_offer.availability_zones = [
405-
volume.provisioning_data.availability_zone
406-
]
407-
instance_offer.availability_zones = [
408-
z
409-
for z in instance_offer.availability_zones
410-
if z == volume.provisioning_data.availability_zone
411-
]
412-
return self.create_instance(instance_offer, instance_config)
413-
414374
def create_gateway(
415375
self,
416376
configuration: GatewayComputeConfiguration,

src/dstack/_internal/core/backends/kubernetes/compute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@
5858

5959

6060
class KubernetesCompute(
61-
Compute,
6261
ComputeWithGatewaySupport,
62+
Compute,
6363
):
6464
def __init__(self, config: KubernetesConfig):
6565
super().__init__()

0 commit comments

Comments
 (0)