Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 126 additions & 13 deletions src/prime_rl/utils/monitor/prime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
import os
import time
from pathlib import Path
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(
if config is not None and config.log_extras:
if config.log_extras.samples:
self.last_log_samples_step = -1
self._pending_sample_steps: set[int] = set()
self.tokenizer = tokenizer
if config.log_extras.distributions:
self.last_log_distributions_step = -1
Expand All @@ -93,7 +95,7 @@ def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
)

def log_samples(self, rollouts: list[vf.State], step: int) -> None:
"""Logs rollouts to Prime Intellect API."""
"""Logs rollouts to Prime Intellect API using presigned URLs for direct R2 upload."""
if not self.is_master:
return
if not self.enabled:
Expand All @@ -108,12 +110,30 @@ def log_samples(self, rollouts: list[vf.State], step: int) -> None:
return

assert self.last_log_samples_step <= step, "Step must be greater than last logged step"
assert step not in self._pending_sample_steps, f"Step {step} upload already in progress"
assert self.logger is not None, "Logger is required for sample logging"

self.logger.info(f"Logging samples to Prime Intellect API at step {step}")
start_time = time.perf_counter()

# Prepare samples for API
samples = self._prepare_samples(rollouts, step)

if not samples:
self.logger.warning(f"No samples to log at step {step}")
return

self._pending_sample_steps.add(step)

# Use presigned URL flow for uploading samples
self._upload_samples_via_presigned_url(samples, step)

self.logger.debug(
f"Initiated samples upload at step {step} to Prime Intellect API in {time.perf_counter() - start_time:.2f}s"
)

def _prepare_samples(self, rollouts: list[vf.State], step: int) -> list[dict[str, Any]]:
"""Prepare samples from rollouts for upload."""
samples = []
for rollout in rollouts:
# Extract prompt and completion separately from the last trajectory step
Expand Down Expand Up @@ -164,19 +184,110 @@ def log_samples(self, rollouts: list[vf.State], step: int) -> None:
}
samples.append(sample)

# Upload samples
self._make_request(
"samples",
{
"run_id": self.run_id,
"step": step,
"samples": samples,
},
)
self.last_log_samples_step = step
self.logger.debug(
f"Logged samples at step {step} to Prime Intellect API in {time.perf_counter() - start_time:.2f}s"
return samples

def _upload_samples_via_presigned_url(self, samples: list[dict[str, Any]], step: int) -> None:
    """Schedule an asynchronous samples upload on the monitor's background event loop (fire-and-forget)."""
    # Hand the coroutine to the loop owned by the background thread; the
    # returned concurrent.futures.Future is kept so shutdown code can wait on it.
    coro = self._upload_samples_via_presigned_url_async(samples, step)
    self._pending_futures.append(asyncio.run_coroutine_threadsafe(coro, self._loop))
    # Prune finished futures so the pending list does not grow without bound.
    self._pending_futures = [pending for pending in self._pending_futures if not pending.done()]

async def _upload_samples_via_presigned_url_async(
    self, samples: list[dict[str, Any]], step: int, max_retries: int = 3
) -> None:
    """Upload samples via the presigned URL flow.

    Three-phase flow: request a presigned URL from the backend, PUT the
    JSON-serialized samples to R2, then confirm the upload with the backend.
    Failures are logged and swallowed (fire-and-forget); the in-flight step
    marker is always cleared so a later call may retry the step.

    Args:
        samples: Prepared sample dicts to serialize and upload.
        step: Training step the samples belong to.
        max_retries: Retry budget for both the R2 upload and the confirm request.
    """
    try:
        presign_data = await self._request_presigned_url(step, len(samples))
        if not presign_data:
            self.logger.warning(f"Failed to get presigned URL for samples at step {step}")
            return

        if "presigned_url" not in presign_data or "s3_key" not in presign_data:
            self.logger.warning(f"Invalid presign response at step {step}")
            return

        presigned_url = presign_data["presigned_url"]
        s3_key = presign_data["s3_key"]
        json_bytes = json.dumps(samples).encode("utf-8")

        upload_success = await self._upload_to_r2(
            presigned_url, json_bytes, content_type="application/json", max_retries=max_retries
        )
        if not upload_success:
            self.logger.warning(f"Failed to upload samples to R2 at step {step}")
            return

        # Fix: forward max_retries so the caller's retry budget also governs the
        # confirm request (previously the confirm step always used its own default).
        confirm_success = await self._confirm_samples_upload(step, s3_key, len(samples), max_retries=max_retries)
        if not confirm_success:
            self.logger.warning(f"Failed to confirm samples upload at step {step}")
            return

        # Only advance the high-water mark once the full flow has succeeded.
        self.last_log_samples_step = step
        self.logger.debug(f"Successfully completed samples upload at step {step}")

    except Exception as e:
        # Fire-and-forget: never let an upload error propagate into the caller.
        self.logger.warning(f"Failed to upload samples via presigned URL at step {step}: {type(e).__name__}: {e}")
    finally:
        # Always release the in-flight marker, success or failure.
        self._pending_sample_steps.discard(step)

async def _request_presigned_url(self, step: int, sample_count: int) -> dict[str, Any] | None:
    """Ask the backend for a presigned upload URL.

    Returns the parsed JSON response, or None on any failure (best-effort:
    errors are logged, never raised).
    """
    payload = {"run_id": self.run_id, "step": step, "sample_count": sample_count}
    try:
        # .json() stays inside the try so malformed response bodies also yield None.
        response = await self._client.post(
            f"{self.base_url}/samples/presign",
            headers={"x-api-key": self.api_key, "Content-Type": "application/json"},
            json=payload,
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        self.logger.warning(f"Failed to request presigned URL: {type(e).__name__}: {e}")
        return None

async def _upload_to_r2(
    self, presigned_url: str, data: bytes, content_type: str = "application/json", max_retries: int = 3
) -> bool:
    """PUT data to R2 via a presigned URL, retrying with exponential backoff.

    Args:
        presigned_url: Signed URL authorizing the PUT.
        data: Raw payload bytes to upload.
        content_type: Content-Type header value (must match what was presigned).
        max_retries: Total number of attempts before giving up.

    Returns:
        True if the upload succeeded, False otherwise (including when
        max_retries <= 0, so the loop body never runs).
    """
    for attempt in range(max_retries):
        try:
            response = await self._client.put(presigned_url, content=data, headers={"Content-Type": content_type})
            response.raise_for_status()
            return True
        except Exception as e:
            if attempt == max_retries - 1:
                self.logger.warning(f"Failed to upload to R2 after {max_retries} attempts: {type(e).__name__}: {e}")
                return False
            # Exponential backoff: 1s, 2s, 4s, ...
            delay = 2**attempt
            self.logger.debug(f"Retrying R2 upload in {delay}s (attempt {attempt + 1}/{max_retries})")
            await asyncio.sleep(delay)
    # Fix: previously fell through with an implicit None when max_retries <= 0;
    # return False explicitly, matching _confirm_samples_upload and the -> bool annotation.
    return False

async def _confirm_samples_upload(self, step: int, s3_key: str, sample_count: int, max_retries: int = 3) -> bool:
    """Notify the backend that the samples object landed in R2.

    Retries with exponential backoff. Returns True on success, False once
    all attempts are exhausted.
    """
    request_body = {"run_id": self.run_id, "step": step, "s3_key": s3_key, "sample_count": sample_count}
    for attempt in range(max_retries):
        try:
            response = await self._client.post(
                f"{self.base_url}/samples/confirm",
                headers={"x-api-key": self.api_key, "Content-Type": "application/json"},
                json=request_body,
            )
            response.raise_for_status()
        except Exception as e:
            if attempt == max_retries - 1:
                self.logger.warning(
                    f"Failed to confirm samples upload after {max_retries} attempts: {type(e).__name__}: {e}"
                )
                return False
            # Exponential backoff: 1s, 2s, 4s, ...
            delay = 2**attempt
            self.logger.debug(f"Retrying samples confirm in {delay}s (attempt {attempt + 1}/{max_retries})")
            await asyncio.sleep(delay)
        else:
            return True
    return False

def log_final_samples(self) -> None:
"""Log final samples (no-op - samples are logged per-step only)."""
Expand Down Expand Up @@ -263,6 +374,8 @@ def _init_async_client(self) -> None:
self._thread.start()
self._client = httpx.AsyncClient(timeout=30)
self._pending_futures: list[asyncio.Future] = []
if hasattr(self, "_pending_sample_steps") and self._pending_sample_steps:
self._pending_sample_steps.clear()

def _reinit_after_fork(self) -> None:
"""Reinitialize thread and event loop after fork."""
Expand Down