Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions libs/mng_modal/imbue/mng_modal/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import ClassVar
from typing import Final
from typing import assert_never

from loguru import logger
from pydantic import ConfigDict
Expand Down Expand Up @@ -42,6 +43,7 @@
from imbue.modal_proxy.interface import AppInterface
from imbue.modal_proxy.interface import ModalInterface
from imbue.modal_proxy.interface import VolumeInterface
from imbue.modal_proxy.testing import TestingModalInterface

MODAL_BACKEND_NAME: Final[ProviderBackendName] = ProviderBackendName("modal")
STATE_VOLUME_SUFFIX: Final[str] = "-state"
Expand Down Expand Up @@ -214,6 +216,7 @@ def _get_or_create_app(
environment_name: str,
is_persistent: bool,
modal_interface: ModalInterface,
is_testing: bool = False,
) -> tuple[AppInterface, ModalAppContextHandle]:
"""Get or create a Modal app with output capture.

Expand All @@ -238,10 +241,18 @@ def _get_or_create_app(
return cls._app_registry[app_name]

with log_span("Creating ephemeral Modal app with output capture: {} (env: {})", app_name, environment_name):
# Enter the output capture context first
with log_span("Enabling Modal output capture"):
output_capture_context = enable_modal_output_capture(is_logging_to_loguru=True)
output_buffer, loguru_writer = output_capture_context.__enter__()
# Testing mode uses a null context instead of Modal output capture,
# which requires Modal SDK internals not available in testing.
if is_testing:
output_buffer = StringIO()
loguru_writer: ModalLoguruWriter | None = None
output_capture_context: AbstractContextManager[tuple[StringIO, ModalLoguruWriter | None]] = (
contextlib.nullcontext((output_buffer, loguru_writer))
)
else:
with log_span("Enabling Modal output capture"):
output_capture_context = enable_modal_output_capture(is_logging_to_loguru=True)
output_buffer, loguru_writer = output_capture_context.__enter__()

if is_persistent:
with log_span("Looking up persistent Modal app: {}", app_name):
Expand Down Expand Up @@ -412,8 +423,15 @@ def build_provider_instance(
match config.mode:
case ModalMode.DIRECT:
modal_interface: ModalInterface = DirectModalInterface()
case _:
raise MngError(f"Unsupported modal mode: {config.mode}")
case ModalMode.TESTING:
testing_root = mng_ctx.profile_dir / "modal_testing"
testing_root.mkdir(parents=True, exist_ok=True)
modal_interface = TestingModalInterface(
root_dir=testing_root,
concurrency_group=mng_ctx.concurrency_group,
)
case _ as unreachable:
assert_never(unreachable)

# Use prefix + user_id for the environment name, ensuring isolation
# between different mng installations sharing the same Modal account.
Expand Down Expand Up @@ -444,7 +462,11 @@ def build_provider_instance(
# Create the ModalProviderApp that manages the Modal app and its resources
try:
app, context_handle = ModalProviderBackend._get_or_create_app(
app_name, environment_name, config.is_persistent, modal_interface
app_name,
environment_name,
config.is_persistent,
modal_interface,
is_testing=config.mode == ModalMode.TESTING,
)
volume = ModalProviderBackend.get_volume_for_app(app_name, modal_interface)

Expand Down
1 change: 1 addition & 0 deletions libs/mng_modal/imbue/mng_modal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class ModalMode(UpperCaseStrEnum):
"""Mode for the Modal provider backend."""

DIRECT = auto()
TESTING = auto()


class ModalProviderConfig(ProviderInstanceConfig):
Expand Down
36 changes: 36 additions & 0 deletions libs/mng_modal/imbue/mng_modal/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
from imbue.mng_modal.config import ModalProviderConfig
from imbue.mng_modal.constants import MODAL_TEST_APP_PREFIX
from imbue.mng_modal.instance import ModalProviderInstance
from imbue.mng_modal.testing import make_testing_modal_interface
from imbue.mng_modal.testing import make_testing_provider
from imbue.modal_proxy.testing import TestingModalInterface


def make_modal_provider_real(
Expand Down Expand Up @@ -446,3 +449,36 @@ def modal_session_cleanup() -> Generator[None, None, None]:
+ "\n\n".join(errors)
+ "\n\nThese resources have been cleaned up, but tests should not leak!\n"
)


# =============================================================================
# Testing Modal Interface fixtures
#
# These fixtures provide a ModalProviderInstance backed by TestingModalInterface
# for testing mng_modal business logic without Modal credentials or SSH.
# =============================================================================


@pytest.fixture
def testing_modal(tmp_path: Path, cg: ConcurrencyGroup) -> TestingModalInterface:
return make_testing_modal_interface(tmp_path, cg)


@pytest.fixture
def testing_provider(
temp_mng_ctx: MngContext,
testing_modal: TestingModalInterface,
) -> Generator[ModalProviderInstance, None, None]:
provider = make_testing_provider(temp_mng_ctx, testing_modal)
yield provider
testing_modal.cleanup()


@pytest.fixture
def testing_provider_no_host_volume(
temp_mng_ctx: MngContext,
testing_modal: TestingModalInterface,
) -> Generator[ModalProviderInstance, None, None]:
provider = make_testing_provider(temp_mng_ctx, testing_modal, is_host_volume_created=False)
yield provider
testing_modal.cleanup()
174 changes: 174 additions & 0 deletions libs/mng_modal/imbue/mng_modal/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Test utilities for mng_modal.

Non-fixture helpers for creating test objects. Fixtures that use these
helpers live in conftest.py.
"""

from datetime import datetime
from datetime import timezone
from pathlib import Path

from imbue.concurrency_group.concurrency_group import ConcurrencyGroup
from imbue.mng.config.data_types import MngContext
from imbue.mng.interfaces.data_types import CertifiedHostData
from imbue.mng.interfaces.data_types import SnapshotRecord
from imbue.mng.primitives import HostId
from imbue.mng.primitives import HostName
from imbue.mng.primitives import ProviderInstanceName
from imbue.mng_modal.backend import STATE_VOLUME_SUFFIX
from imbue.mng_modal.config import ModalMode
from imbue.mng_modal.config import ModalProviderConfig
from imbue.mng_modal.instance import HostRecord
from imbue.mng_modal.instance import ModalProviderApp
from imbue.mng_modal.instance import ModalProviderInstance
from imbue.mng_modal.instance import SandboxConfig
from imbue.mng_modal.instance import TAG_HOST_ID
from imbue.mng_modal.instance import TAG_HOST_NAME
from imbue.mng_modal.instance import TAG_USER_PREFIX
from imbue.modal_proxy.interface import SandboxInterface
from imbue.modal_proxy.testing import TestingModalInterface

_DEFAULT_SANDBOX_CONFIG = SandboxConfig()


def make_testing_modal_interface(tmp_path: Path, cg: ConcurrencyGroup) -> TestingModalInterface:
"""Create a TestingModalInterface rooted in a temp directory."""
root = tmp_path / "modal_testing"
root.mkdir(parents=True, exist_ok=True)
return TestingModalInterface(root_dir=root, concurrency_group=cg)


def make_testing_provider(
mng_ctx: MngContext,
modal_interface: TestingModalInterface,
app_name: str = "test-app",
is_persistent: bool = False,
is_snapshotted_after_create: bool = False,
is_host_volume_created: bool = True,
) -> ModalProviderInstance:
"""Create a ModalProviderInstance backed by TestingModalInterface."""
environment_name = f"{mng_ctx.config.prefix}test-user"

app = modal_interface.app_lookup(app_name, create_if_missing=True, environment_name=environment_name)
volume_name = f"{app_name}{STATE_VOLUME_SUFFIX}"
volume = modal_interface.volume_from_name(
volume_name,
create_if_missing=True,
environment_name=environment_name,
)

config = ModalProviderConfig(
mode=ModalMode.TESTING,
app_name=app_name,
host_dir=mng_ctx.config.default_host_dir,
default_sandbox_timeout=300,
default_cpu=0.5,
default_memory=0.5,
is_persistent=is_persistent,
is_snapshotted_after_create=is_snapshotted_after_create,
is_host_volume_created=is_host_volume_created,
)

modal_app = ModalProviderApp(
app_name=app_name,
environment_name=environment_name,
app=app,
volume=volume,
modal_interface=modal_interface,
close_callback=lambda: None,
get_output_callback=lambda: "",
)

return ModalProviderInstance(
name=ProviderInstanceName("modal-test"),
host_dir=mng_ctx.config.default_host_dir,
mng_ctx=mng_ctx,
config=config,
modal_app=modal_app,
)


def make_snapshot(snap_id: str = "snap-1", name: str = "s1") -> SnapshotRecord:
"""Create a SnapshotRecord for testing."""
return SnapshotRecord(id=snap_id, name=name, created_at=datetime.now(timezone.utc).isoformat())


def make_host_record(
host_id: HostId | None = None,
host_name: str = "test-host",
snapshots: list[SnapshotRecord] | None = None,
failure_reason: str | None = None,
user_tags: dict[str, str] | None = None,
config: SandboxConfig | None = _DEFAULT_SANDBOX_CONFIG,
ssh_host: str | None = "127.0.0.1",
ssh_port: int | None = 22222,
ssh_host_public_key: str | None = "ssh-ed25519 AAAA...",
) -> HostRecord:
"""Create a HostRecord for testing."""
if host_id is None:
host_id = HostId.generate()
now = datetime.now(timezone.utc)
certified_data = CertifiedHostData(
host_id=str(host_id),
host_name=host_name,
user_tags=user_tags or {},
snapshots=snapshots or [],
failure_reason=failure_reason,
created_at=now,
updated_at=now,
)
return HostRecord(
certified_host_data=certified_data,
ssh_host=ssh_host,
ssh_port=ssh_port,
ssh_host_public_key=ssh_host_public_key,
config=config,
)


def make_sandbox_with_tags(
modal_interface: TestingModalInterface,
host_id: HostId,
host_name: str,
user_tags: dict[str, str] | None = None,
) -> SandboxInterface:
"""Create a testing sandbox with mng tags set."""
image = modal_interface.image_debian_slim()
app = list(modal_interface._apps.values())[0]
sandbox = modal_interface.sandbox_create(
image=image,
app=app,
timeout=300,
cpu=1.0,
memory=1024,
)
tags: dict[str, str] = {
TAG_HOST_ID: str(host_id),
TAG_HOST_NAME: host_name,
}
if user_tags:
for key, value in user_tags.items():
tags[TAG_USER_PREFIX + key] = value
sandbox.set_tags(tags)
return sandbox


def setup_host_with_sandbox(
testing_provider: ModalProviderInstance,
testing_modal: TestingModalInterface,
host_name: str,
user_tags: dict[str, str] | None = None,
) -> tuple[HostId, HostRecord, SandboxInterface]:
"""Common setup: create a host record, sandbox with tags, and cache both.

Returns (host_id, record, sandbox). The host cache is populated with an
OfflineHost so that get_host() returns it without SSH.
"""
host_id = HostId.generate()
record = make_host_record(host_id=host_id, host_name=host_name, user_tags=user_tags)
testing_provider._write_host_record(record)
sandbox = make_sandbox_with_tags(testing_modal, host_id, host_name, user_tags=user_tags)
testing_provider._cache_sandbox(host_id, HostName(host_name), sandbox)
offline = testing_provider._create_host_from_host_record(record)
testing_provider._host_by_id_cache[host_id] = offline
return host_id, record, sandbox
Loading
Loading