Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion libs/mng/imbue/mng/api/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@
from imbue.mng.api.discovery_events import emit_discovery_events_for_host
from imbue.mng.api.providers import get_provider_instance
from imbue.mng.config.data_types import MngContext
from imbue.mng.errors import MngError
from imbue.mng.hosts.host import HostLocation
from imbue.mng.interfaces.host import CreateAgentOptions
from imbue.mng.interfaces.host import HostEnvironmentOptions
from imbue.mng.interfaces.host import NewHostOptions
from imbue.mng.interfaces.host import OnlineHostInterface
from imbue.mng.interfaces.provider_instance import ProviderInstanceInterface
from imbue.mng.plugins.hookspecs import OnBeforeCreateArgs
from imbue.mng.primitives import HostName
from imbue.mng.primitives import HostNameStyle
from imbue.mng.utils.env_utils import parse_env_file


Expand Down Expand Up @@ -186,6 +190,47 @@ def _write_host_env_vars(
host.set_env_vars(env_vars)


_MAX_AUTO_NAME_ATTEMPTS: int = 20


def _generate_unique_host_name(
provider: ProviderInstanceInterface,
style: HostNameStyle,
mng_ctx: MngContext,
) -> HostName:
"""Generate a host name that does not collide with any existing host on the provider.

Discovers existing hosts, then generates names until a unique one is found.
If the provider returns a fixed name (e.g. "localhost" for the local provider),
accepts it even if it matches an existing host, since such providers are designed
to reuse the same host.

Raises MngError if no unique name can be generated after _MAX_AUTO_NAME_ATTEMPTS attempts.
"""
with log_span("Discovering existing hosts for unique name generation"):
existing_hosts = provider.discover_hosts(cg=mng_ctx.concurrency_group)
existing_names = {h.host_name for h in existing_hosts}

first_candidate = provider.get_host_name(style)
if first_candidate not in existing_names:
return first_candidate

for _ in range(_MAX_AUTO_NAME_ATTEMPTS - 1):
candidate = provider.get_host_name(style)
if candidate not in existing_names:
return candidate
if candidate == first_candidate:
# The provider returns a fixed name (e.g. "localhost") rather than
# a random one. Accept it as-is -- such providers are designed to
# reuse the same host.
return candidate

raise MngError(
f"Failed to generate a unique host name after {_MAX_AUTO_NAME_ATTEMPTS} attempts. "
f"There are {len(existing_names)} existing hosts on provider '{provider.name}'."
)


def resolve_target_host(
target_host: OnlineHostInterface | NewHostOptions,
mng_ctx: MngContext,
Expand All @@ -195,7 +240,9 @@ def resolve_target_host(
# Create a new host using the specified provider
provider = get_provider_instance(target_host.provider, mng_ctx)
host_name = (
target_host.name if target_host.name is not None else provider.get_host_name(target_host.name_style)
target_host.name
if target_host.name is not None
else _generate_unique_host_name(provider, target_host.name_style, mng_ctx)
)

with log_span("Calling on_before_host_create hooks"):
Expand Down
132 changes: 132 additions & 0 deletions libs/mng/imbue/mng/api/create_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
from pathlib import Path

import pytest
from pydantic import Field
from pydantic import PrivateAttr

from imbue.concurrency_group.concurrency_group import ConcurrencyGroup
from imbue.mng.api.create import _generate_unique_host_name
from imbue.mng.api.create import _write_host_env_vars
from imbue.mng.api.create import resolve_target_host
from imbue.mng.config.data_types import EnvVar
from imbue.mng.config.data_types import MngContext
from imbue.mng.errors import MngError
from imbue.mng.hosts.host import Host
from imbue.mng.interfaces.host import HostEnvironmentOptions
from imbue.mng.interfaces.host import OnlineHostInterface
from imbue.mng.primitives import DiscoveredHost
from imbue.mng.primitives import HostId
from imbue.mng.primitives import HostName
from imbue.mng.primitives import HostNameStyle
from imbue.mng.primitives import ProviderInstanceName
from imbue.mng.providers.mock_provider_test import MockProviderInstance


def test_write_host_env_vars_writes_explicit_env_vars(
Expand Down Expand Up @@ -122,3 +135,122 @@ def test_write_host_env_vars_later_env_file_overrides_earlier(
assert host_env["SHARED"] == "from_second"
assert host_env["FIRST_ONLY"] == "present"
assert host_env["SECOND_ONLY"] == "present"


# =============================================================================
# _generate_unique_host_name Tests
# =============================================================================


class _SequentialNameProvider(MockProviderInstance):
"""Mock provider that returns names from a predefined sequence.

Also overrides discover_hosts to return a configurable set of discovered hosts
for uniqueness testing.
"""

sequential_names: tuple[HostName, ...] = Field(default=(), description="Names to return in sequence")
discovered_hosts_override: tuple[DiscoveredHost, ...] = Field(
default=(), description="Hosts to return from discover_hosts"
)
_call_count: int = PrivateAttr(default=0)

def get_host_name(self, style: HostNameStyle) -> HostName:
index = min(self._call_count, len(self.sequential_names) - 1)
self._call_count += 1
return self.sequential_names[index]

def discover_hosts(
self,
cg: ConcurrencyGroup,
include_destroyed: bool = False,
) -> list[DiscoveredHost]:
if self.discovered_hosts_override:
return list(self.discovered_hosts_override)
return super().discover_hosts(cg=cg, include_destroyed=include_destroyed)


def _make_discovered_host(name: str, provider_name: str = "test") -> DiscoveredHost:
return DiscoveredHost(
host_id=HostId.generate(),
host_name=HostName(name),
provider_name=ProviderInstanceName(provider_name),
)


def test_generate_unique_host_name_no_existing_hosts(
temp_mng_ctx: MngContext,
temp_host_dir: Path,
) -> None:
"""_generate_unique_host_name should return the first name when no hosts exist."""
provider = _SequentialNameProvider(
sequential_names=(HostName("alpha"),),
name=ProviderInstanceName("test"),
host_dir=temp_host_dir,
mng_ctx=temp_mng_ctx,
)

result = _generate_unique_host_name(provider, HostNameStyle.ASTRONOMY, temp_mng_ctx)
assert result == HostName("alpha")


def test_generate_unique_host_name_skips_colliding_names(
temp_mng_ctx: MngContext,
temp_host_dir: Path,
) -> None:
"""_generate_unique_host_name should skip names that collide with existing hosts."""
provider = _SequentialNameProvider(
sequential_names=(HostName("taken"), HostName("also-taken"), HostName("unique")),
discovered_hosts_override=(
_make_discovered_host("taken", "test"),
_make_discovered_host("also-taken", "test"),
),
name=ProviderInstanceName("test"),
host_dir=temp_host_dir,
mng_ctx=temp_mng_ctx,
)

result = _generate_unique_host_name(provider, HostNameStyle.ASTRONOMY, temp_mng_ctx)
assert result == HostName("unique")


def test_generate_unique_host_name_accepts_fixed_name_provider(
temp_mng_ctx: MngContext,
temp_host_dir: Path,
) -> None:
"""_generate_unique_host_name should accept a colliding name from a fixed-name provider.

Providers like the local provider always return the same name (e.g. "localhost")
even when a host with that name already exists. The function should detect this
and return the name as-is.
"""
provider = _SequentialNameProvider(
sequential_names=(HostName("localhost"), HostName("localhost")),
discovered_hosts_override=(_make_discovered_host("localhost", "test"),),
name=ProviderInstanceName("test"),
host_dir=temp_host_dir,
mng_ctx=temp_mng_ctx,
)

result = _generate_unique_host_name(provider, HostNameStyle.ASTRONOMY, temp_mng_ctx)
assert result == HostName("localhost")


def test_generate_unique_host_name_raises_after_max_attempts(
temp_mng_ctx: MngContext,
temp_host_dir: Path,
) -> None:
"""_generate_unique_host_name should raise MngError when all random names collide."""
# Each attempt returns a different name, but all collide with existing hosts.
# This ensures the fixed-name-provider detection doesn't short-circuit.
taken_names = tuple(HostName(f"taken-{i}") for i in range(20))
provider = _SequentialNameProvider(
sequential_names=taken_names,
discovered_hosts_override=tuple(_make_discovered_host(str(n), "test") for n in taken_names),
name=ProviderInstanceName("test"),
host_dir=temp_host_dir,
mng_ctx=temp_mng_ctx,
)

with pytest.raises(MngError, match="Failed to generate a unique host name"):
_generate_unique_host_name(provider, HostNameStyle.ASTRONOMY, temp_mng_ctx)
Loading