|
2 | 2 | import random |
3 | 3 | import shlex |
4 | 4 | import time |
5 | | -from dataclasses import dataclass |
6 | 5 | from functools import cached_property |
7 | 6 | from typing import List, Optional |
8 | 7 |
|
|
21 | 20 | ) |
22 | 21 | from dstack._internal.core.backends.base.offers import get_catalog_offers |
23 | 22 | from dstack._internal.core.backends.nebius import resources |
| 23 | +from dstack._internal.core.backends.nebius.fabrics import get_suitable_infiniband_fabrics |
24 | 24 | from dstack._internal.core.backends.nebius.models import NebiusConfig, NebiusServiceAccountCreds |
25 | 25 | from dstack._internal.core.errors import ( |
26 | 26 | BackendError, |
|
81 | 81 | ] |
82 | 82 |
|
83 | 83 |
|
84 | | -@dataclass(frozen=True) |
85 | | -class InfinibandFabric: |
86 | | - name: str |
87 | | - platform: str |
88 | | - region: str |
89 | | - |
90 | | - |
91 | | -# https://docs.nebius.com/compute/clusters/gpu#fabrics |
92 | | -INFINIBAND_FABRICS = [ |
93 | | - InfinibandFabric("fabric-2", "gpu-h100-sxm", "eu-north1"), |
94 | | - InfinibandFabric("fabric-3", "gpu-h100-sxm", "eu-north1"), |
95 | | - InfinibandFabric("fabric-4", "gpu-h100-sxm", "eu-north1"), |
96 | | - InfinibandFabric("fabric-5", "gpu-h200-sxm", "eu-west1"), |
97 | | - InfinibandFabric("fabric-6", "gpu-h100-sxm", "eu-north1"), |
98 | | - InfinibandFabric("fabric-7", "gpu-h200-sxm", "eu-north1"), |
99 | | -] |
100 | | - |
101 | | - |
102 | 84 | class NebiusCompute( |
103 | 85 | ComputeWithCreateInstanceSupport, |
104 | 86 | ComputeWithMultinodeSupport, |
@@ -280,7 +262,9 @@ def create_placement_group( |
280 | 262 | backend_data = NebiusPlacementGroupBackendData(cluster=None) |
281 | 263 | # Only create a Nebius cluster if the instance supports it. |
282 | 264 | # For other instances, return dummy PlacementGroupProvisioningData. |
283 | | - if fabrics := _get_suitable_infiniband_fabrics(master_instance_offer): |
| 265 | + if fabrics := get_suitable_infiniband_fabrics( |
| 266 | + master_instance_offer, allowed_fabrics=self.config.fabrics |
| 267 | + ): |
284 | 268 | fabric = random.choice(fabrics) |
285 | 269 | op = resources.create_cluster( |
286 | 270 | self._sdk, |
@@ -319,7 +303,11 @@ def is_suitable_placement_group( |
319 | 303 | ) |
320 | 304 | return ( |
321 | 305 | backend_data.cluster is None |
322 | | - or backend_data.cluster.fabric in _get_suitable_infiniband_fabrics(instance_offer) |
| 306 | + or backend_data.cluster.fabric |
| 307 | + in get_suitable_infiniband_fabrics( |
| 308 | + instance_offer, |
| 309 | + allowed_fabrics=None, # enforced at cluster creation time, no need to enforce here |
| 310 | + ) |
323 | 311 | ) |
324 | 312 |
|
325 | 313 |
|
@@ -380,15 +368,3 @@ def _wait_for_instance(sdk: SDK, op: SDKOperation[Operation]) -> None: |
380 | 368 | def _supported_instances(offer: InstanceOffer) -> bool: |
381 | 369 | platform, _ = offer.instance.name.split() |
382 | 370 | return platform in SUPPORTED_PLATFORMS and not offer.instance.resources.spot |
383 | | - |
384 | | - |
385 | | -def _get_suitable_infiniband_fabrics(offer: InstanceOffer) -> list[str]: |
386 | | - if len(offer.instance.resources.gpus) < 8: |
387 | | - # From the create VM page in the Nebius Console: |
388 | | - # > Only virtual machines with at least 8 NVIDIA® Hopper® H100 or H200 GPUs |
389 | | - # > can be added to the cluster |
390 | | - return [] |
391 | | - platform, _ = offer.instance.name.split() |
392 | | - return [ |
393 | | - f.name for f in INFINIBAND_FABRICS if f.platform == platform and f.region == offer.region |
394 | | - ] |
0 commit comments