Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,7 @@ assets/wheels/vllm*.whl
# DCP artifacts
forge_dcp_tmp/
demo_top_down.md


# enroot / sqsh
*.sqsh
3 changes: 3 additions & 0 deletions apps/vllm/llama3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,8 @@ services:
with_gpus: true


num_calls: 100
# llama3 8b is too dumb to write code
with_coder: false
# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
69 changes: 63 additions & 6 deletions apps/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
import asyncio

import os
import re
import time

from forge.actors.coder import SandboxedCoder

from forge.actors.policy import Policy
from forge.cli.config import parse
Expand All @@ -25,22 +29,57 @@
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"


def extract_python_code(text: str) -> str:
"""Extract Python code from ```python``` markdown code blocks.

Args:
text: Text that may contain Python code in markdown blocks

Returns:
Extracted Python code, or original text if no code blocks found
"""
# Look for ```python code blocks
pattern = r"```python(.*?)```"
matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)

if matches:
# Return the first match, stripped of extra whitespace
return matches[0].strip()
else:
# If no python blocks found, return the original text (fallback)
return text.strip()


async def run(cfg: DictConfig):
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)

if (prompt := cfg.get("prompt")) is None:
gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
prompt = "What is 3+5?" if gd else "Tell me a joke"
# Use different prompts based on whether we're using coder
if cfg.with_coder:
if (prompt := cfg.get("prompt")) is None:
prompt = "Write a Python function that calculates the factorial of a number and test it with factorial(5). Include the test call and print the result. Please wrap your code in ```python``` code blocks."
else:
if (prompt := cfg.get("prompt")) is None:
gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
prompt = "What is 3+5?" if gd else "Tell me a joke"

print("Spawning service...")
print("Spawning services...")
policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)

import time
coder = None
with_coder = cfg.get("with_coder", False)
n = cfg.get("num_calls", 100)
if with_coder:
print("Setting up coder...")
coder = await SandboxedCoder.options(**cfg.services.coder).as_service(
docker_image="docker://python:3.10",
sqsh_image_path="python-image.sqsh",
container_name="sandbox",
)

print("Requesting generation...")
n = 100
n = cfg.num_calls
start = time.time()
response_outputs: list[Completion] = await asyncio.gather(
*[policy.generate.route(prompt=prompt) for _ in range(n)]
Expand All @@ -58,10 +97,28 @@ async def run(cfg: DictConfig):
print(f"Sample {batch + 1}:")
print(f"User: {prompt}")
print(f"Assistant: {response.text}")

# If we have a coder, try to execute the generated code
if coder and with_coder:
print(f"Parsing and executing generated code...")
try:
# Extract Python code from tags
python_code = extract_python_code(response.text)
print(f"Extracted Code:\n{python_code}")
print("-" * 40)

# Execute the extracted code
execution_result = await coder.execute.route(code=python_code)
print(f"Execution Output:\n{execution_result}")
except Exception as e:
print(f"Execution Error: {e}")

print("-" * 80)

print("\nShutting down...")
await policy.shutdown()
if coder:
await coder.shutdown()
await shutdown()


Expand Down
10 changes: 7 additions & 3 deletions apps/vllm/qwen2_5_32b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ policy:
services:
policy:
procs: 4
hosts: 1
num_replicas: 1
with_gpus: true
coder:
procs: 1
num_replicas: 1
with_gpus: false


num_calls: 1
# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
prompt: "Write a Python function that calculates the factorial of a number and test it with factorial(5). Please wrap your code in ```python``` code blocks."
with_coder: true
13 changes: 12 additions & 1 deletion src/forge/actors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"]
__all__ = [
"Policy",
"PolicyRouter",
"RLTrainer",
"ReplayBuffer",
"TitanRefModel",
"SandboxedCoder",
]


def __getattr__(name):
Expand All @@ -28,5 +35,9 @@ def __getattr__(name):
from .reference_model import ReferenceModel

return ReferenceModel
elif name == "SandboxedCoder":
from .coder import SandboxedCoder

return SandboxedCoder
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
145 changes: 145 additions & 0 deletions src/forge/actors/coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import subprocess
import tempfile
from pathlib import Path

from monarch.actor import endpoint

from forge.controller import ForgeActor

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class SandboxedCoder(ForgeActor):
"""A sandboxed code execution environment using enroot containers.
SandboxedCoder provides a secure, isolated environment for executing Python code
using NVIDIA's enroot containerization technology. It automatically manages the
entire container lifecycle including image import, container creation, and cleanup.
The actor follows a three-stage workflow:
1. Image Management: Automatically imports Docker images to enroot .sqsh format
2. Container Lifecycle: Creates fresh container instances for isolated execution
3. Code Execution: Safely runs Python code with proper error handling and output capture
Dependencies:
- enroot: NVIDIA's container runtime (must be installed on host)
- Docker images: Accessible via docker:// URLs or local paths
- Python 3.x: For the container environment
Args:
docker_image: Docker image URL to import (e.g., "docker://python:3.10").
Can be any Docker Hub image or custom registry URL.
sqsh_image_path: Local filesystem path where the enroot .sqsh image will be stored.
If the file doesn't exist, it will be created via enroot import.
container_name: Unique name for the enroot container instance. Used for
container lifecycle management (create/remove operations).
"""

def __init__(
self,
docker_image: str = "docker://python:3.10",
sqsh_image_path: str = "python-image.sqsh",
container_name: str = "sandbox",
):
self.docker_image = docker_image
self.sqsh_image_path = sqsh_image_path
self.container_name = container_name
self._initialized = False

@endpoint
async def setup(self):
logging.debug("Setting up sandboxed actor")
await self._ensure_image()
self._reset()

@endpoint
async def reset(self):
self._reset()

async def _ensure_image(self):
"""Ensure the enroot image exists, import it if necessary."""
if not os.path.exists(self.sqsh_image_path):
logging.debug(
f"Image {self.sqsh_image_path} not found, importing from {self.docker_image}"
)
result = subprocess.run(
["enroot", "import", "-o", self.sqsh_image_path, self.docker_image],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"Failed to import image: {result.stderr}")
logging.debug(
f"Successfully imported {self.docker_image} to {self.sqsh_image_path}"
)
else:
logging.info(f"Using existing image: {self.sqsh_image_path}")

def _reset(self):
"""(Re)create a clean container instance from the base image."""
# Remove any old container
logging.debug(f"Removing container {self.container_name}")
subprocess.run(
["enroot", "remove", "-f", self.container_name],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# Create new container from image
result = subprocess.run(
["enroot", "create", "--name", self.container_name, self.sqsh_image_path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
logging.debug(f"Container creation result: {result}")
if result.returncode != 0:
raise RuntimeError(f"Failed to reset container: {result.stderr}")
self._initialized = True
logging.debug("Successfully initialized container")

@endpoint
async def execute(self, code: str) -> str:
"""
Execute Python code inside the container.
:param code: Python source code string to execute.
:return: Captured stdout.
"""
logging.debug(f"Executing {code}")
if not self._initialized:
raise RuntimeError("Container not initialized. Call reset() first.")

# Write code to a temporary file that we can mount
with tempfile.TemporaryDirectory() as tmpdir:
code_path = Path(tmpdir) / "script.py"
code_path.write_text(code)

# Run the code inside the container, mounting tmpdir
cmd = [
"enroot",
"start",
"--mount",
f"{tmpdir}:/work",
self.container_name,
"python3",
"/work/script.py",
]
result = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"Execution failed:\n{result.stderr}")
return result.stdout
Loading