Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,7 @@ assets/wheels/vllm*.whl
# DCP artifacts
forge_dcp_tmp/
demo_top_down.md


# enroot / sqsh
*.sqsh
13 changes: 12 additions & 1 deletion src/forge/actors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"]
__all__ = [
"Policy",
"PolicyRouter",
"RLTrainer",
"ReplayBuffer",
"TitanRefModel",
"SandboxedCoder",
]


def __getattr__(name):
Expand All @@ -28,5 +35,9 @@ def __getattr__(name):
from .reference_model import ReferenceModel

return ReferenceModel
elif name == "SandboxedCoder":
from .coder import SandboxedCoder

return SandboxedCoder
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
145 changes: 145 additions & 0 deletions src/forge/actors/coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import subprocess
import tempfile
from pathlib import Path

from monarch.actor import endpoint

from forge.controller import ForgeActor

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class SandboxedCoder(ForgeActor):
"""A sandboxed code execution environment using enroot containers.
SandboxedCoder provides a secure, isolated environment for executing Python code
using NVIDIA's enroot containerization technology. It automatically manages the
entire container lifecycle including image import, container creation, and cleanup.
The actor follows a three-stage workflow:
1. Image Management: Automatically imports Docker images to enroot .sqsh format
2. Container Lifecycle: Creates fresh container instances for isolated execution
3. Code Execution: Safely runs Python code with proper error handling and output capture
Dependencies:
- enroot: NVIDIA's container runtime (must be installed on host)
- Docker images: Accessible via docker:// URLs or local paths
- Python 3.x: For the container environment
Args:
docker_image: Docker image URL to import (e.g., "docker://python:3.10").
Can be any Docker Hub image or custom registry URL.
sqsh_image_path: Local filesystem path where the enroot .sqsh image will be stored.
If the file doesn't exist, it will be created via enroot import.
container_name: Unique name for the enroot container instance. Used for
container lifecycle management (create/remove operations).
"""

def __init__(
self,
docker_image: str = "docker://python:3.10",
sqsh_image_path: str = "python-image.sqsh",
container_name: str = "sandbox",
):
self.docker_image = docker_image
self.sqsh_image_path = sqsh_image_path
self.container_name = container_name
self._initialized = False

@endpoint
async def setup(self):
logging.debug("Setting up sandboxed actor")
await self._ensure_image()
self._reset()

@endpoint
async def reset(self):
self._reset()

async def _ensure_image(self):
"""Ensure the enroot image exists, import it if necessary."""
if not os.path.exists(self.sqsh_image_path):
logging.debug(
f"Image {self.sqsh_image_path} not found, importing from {self.docker_image}"
)
result = subprocess.run(
["enroot", "import", "-o", self.sqsh_image_path, self.docker_image],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"Failed to import image: {result.stderr}")
logging.debug(
f"Successfully imported {self.docker_image} to {self.sqsh_image_path}"
)
else:
logging.info(f"Using existing image: {self.sqsh_image_path}")

def _reset(self):
"""(Re)create a clean container instance from the base image."""
# Remove any old container
logging.debug(f"Removing container {self.container_name}")
subprocess.run(
["enroot", "remove", "-f", self.container_name],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# Create new container from image
result = subprocess.run(
["enroot", "create", "--name", self.container_name, self.sqsh_image_path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
logging.debug(f"Container creation result: {result}")
if result.returncode != 0:
raise RuntimeError(f"Failed to reset container: {result.stderr}")
self._initialized = True
logging.debug("Successfully initialized container")

@endpoint
async def execute(self, code: str) -> str:
"""
Execute Python code inside the container.
:param code: Python source code string to execute.
:return: Captured stdout.
"""
logging.debug(f"Executing {code}")
if not self._initialized:
raise RuntimeError("Container not initialized. Call reset() first.")

# Write code to a temporary file that we can mount
with tempfile.TemporaryDirectory() as tmpdir:
code_path = Path(tmpdir) / "script.py"
code_path.write_text(code)

# Run the code inside the container, mounting tmpdir
cmd = [
"enroot",
"start",
"--mount",
f"{tmpdir}:/work",
self.container_name,
"python3",
"/work/script.py",
]
result = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"Execution failed:\n{result.stderr}")
return result.stdout
6 changes: 6 additions & 0 deletions src/forge/controller/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,12 @@ async def host_mesh_from_proc(self, proc_mesh: ProcMesh):

async def stop_proc_mesh(self, proc_mesh: ProcMesh):
"""Stops a proc mesh."""
if proc_mesh not in self._proc_host_map:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this in here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for some reason a bunch of unit tests were failing due to wrong cleanup. Not sure why I saw this in this PR specifically, but we should have this line anyways. If we want to be really clean I can add this in a separate PR

logger.warning(
f"proc mesh {proc_mesh} was requested to be stopped, but was either already stopped or "
"was never registered with the provisioner."
)
return
async with self._lock:
# Deregister local logger from global logger
if hasattr(proc_mesh, "_local_fetcher"):
Expand Down
52 changes: 52 additions & 0 deletions tests/integration_tests/test_coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Integration tests for forge.actors.coder.SandboxedCoder.
Requires enroot to be installed.
"""

import os
import uuid

import pytest

from forge.actors.coder import SandboxedCoder


@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_coder_runs_python():
"""Integration test for SandboxedCoder with real container execution."""
# Create unique names to avoid test conflicts
unique_id = str(uuid.uuid4())[:8]
container_name = f"test_sandbox_{unique_id}"
image_path = f"/tmp/python_test_{unique_id}.sqsh"

coder = None
try:
coder = await SandboxedCoder.as_actor(
docker_image="docker://python:3.10",
sqsh_image_path=image_path,
container_name=container_name,
)

# Execute code
results = await coder.execute.call_one(
code="print('hello world')",
)
assert results == "hello world\n"

finally:
# Clean up resources
if coder:
await SandboxedCoder.shutdown(coder)

# Clean up the image file
if os.path.exists(image_path):
os.unlink(image_path)
87 changes: 87 additions & 0 deletions tests/unit_tests/test_coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Unit tests for forge.actors.coder.SandboxedCoder.
"""
import os
import tempfile
import uuid
from unittest.mock import Mock, patch

import pytest
from forge.actors.coder import SandboxedCoder

from monarch.actor import this_proc


@pytest.mark.timeout(10)
@pytest.mark.asyncio
async def test_coder_execution():
"""Tests basic coder execution with mocked enroot."""
unique_id = str(uuid.uuid4())[:8]
container_name = f"test_sandbox_{unique_id}"

with tempfile.NamedTemporaryFile(suffix=".sqsh", delete=False) as temp_image:
image_path = temp_image.name

coder = None
try:
with patch("subprocess.run") as mock_run:

def mock_subprocess_run(*args, **kwargs):
# Figure out which call this is based on the command
cmd = args[0]
if "import" in cmd:
result = Mock()
result.returncode = 0
result.stderr = ""
return result
elif "remove" in cmd:
result = Mock()
result.returncode = 0
return result
elif "create" in cmd:
result = Mock()
result.returncode = 0
result.stderr = ""
return result
elif "start" in cmd:
result = Mock()
result.returncode = 0
result.stdout = "hello world\n"
result.stderr = ""
print(f"Mock execute result: stdout = {repr(result.stdout)}")
return result
else:
raise ValueError(f"Unexpected subprocess call: {cmd}")

mock_run.side_effect = mock_subprocess_run

coder = this_proc().spawn(
"coder",
SandboxedCoder,
"docker://python:3.10",
image_path,
container_name,
)
await coder.setup.call_one()

# Execute code (this will trigger more mocked subprocess calls)
results = await coder.execute.call_one(
code="print('hello world')",
)

# Verify the result
assert results == "hello world\n"

finally:
# Clean up resources
if coder:
await SandboxedCoder.shutdown(coder)

if os.path.exists(image_path):
os.unlink(image_path)
Loading