Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions conf/sec.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Multi-domain (math + coding) training config with optional bandit-based
# curriculum learning. Extends the shared `base` config via Hydra defaults.
# NOTE(review): leading indentation was lost in extraction; nesting below is
# reconstructed. `domain_mix` is read as cfg.actor.domain_mix in actor.py, so
# that group is placed under `actor` — confirm the sandbox_* keys belong there
# too rather than at top level.
defaults:
  - base
  # Loads the `base` domain_rollouts group into the top-level
  # `domain_rollouts` package, referenced by the interpolations below.
  - /domain_rollouts@domain_rollouts: base
  - _self_

actor:
  shared_memory_entry_size: 2000000000
  # Dispatcher forwards each problem to its domain's rollout generator.
  rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout
  # No system prompt - model's chat template provides guidance
  system_prompt: ""
  # Minimal task template - each problem contains its own instructions
  task_template: |-
    {task}
  task_prompt: ""

  # Per-domain rollout settings, pulled from the defaults-list package above.
  domain_rollouts:
    math: ${domain_rollouts.math}
    coding: ${domain_rollouts.coding}

  # Optional sampling weights per domain (disabled: curriculum/uniform used).
  # domain_mix:
  #   math: 0.5
  #   coding: 0.5

  # SandboxFusion verification settings
  sandbox_endpoint: ${oc.env:SANDBOX_ENDPOINT,http://127.0.0.1:8080}
  sandbox_timeout: 10.0
  max_tests_per_problem: 5

preprocess:
  shared_memory_entry_size: 2000000000

finetune:
  seq_length: 32000

vllm_config:
  vllm_kwargs:
    # Kept in sync with finetune.seq_length above.
    max_model_len: 32000

llm:
  parameters:
    max_tokens: 16000

test_llm:
  parameters:
    max_tokens: 16000

# Bandit-based curriculum learning configuration
curriculum:
  # Enable curriculum learning
  enabled: true
  # How difficulty is determined: "field" (from difficulty_field) or "estimated" (from success rates)
  difficulty_source: "field"
  # Field name for difficulty/level (used when difficulty_source="field")
  difficulty_field: null
  # Additional field(s) for categorization beyond difficulty (optional)
  # Can be a single string: "dataset" or a list: ["dataset", "type"]
  category_fields: ["domain"]
  # Softmax temperature (higher = more exploration, lower = more exploitation)
  temperature: 0.4
  # Q-value update learning rate
  learning_rate: 0.5
  # Initial Q-value for new categories
  initial_q_value: 0.0
  # Signal for Q-update: "advantage", "reward", or "success"
  update_signal: "advantage"
  # Number of difficulty buckets (only used when difficulty_source="estimated")
  # Problems are grouped by success rate into buckets:
  # e.g., 5 buckets: [0-0.2), [0.2-0.4), [0.4-0.6), [0.6-0.8), [0.8-1.0]
  num_difficulty_buckets: 5
  # How often to reassign problems to buckets (only used when difficulty_source="estimated")
  # Counted in preprocessor batches - each batch updates success rates and may trigger reindex
  # Lower = more responsive to changing success rates, higher = more stable bucket assignments
  reindex_interval: 5

dataset_loader: pipelinerl.domains.multidomain.loader.load_datasets
dataset_loader_params:
  per_domain_params:
    coding:
      # TACO + APPS
      taco_split: train
      apps_split: train
      subset: train
      train_ratio: 0.9
      max_tests_per_problem: 5
      taco_excluded_difficulties: [VERY_HARD]
      skip_apps: false
      max_examples: null
      seed: 42
      # Read from the environment; null if HF_TOKEN is unset.
      huggingface_token: ${oc.env:HF_TOKEN, null}

environments:
  - key: math
    mode: remote
    replicas_per_actor: ${world.env_replicas_per_actor}
    _target_: pipelinerl.domains.math.MathEnvironment

environment_key: null

world:
  env_replicas_per_actor: 1

# Dataset names are namespaced as <domain>::<dataset>.
train_dataset_names:
  - coding::taco
  - coding::apps
  - math::open_reasoner_zero_57k
  - math::open_reasoner_zero_extended_72k

test_dataset_names:
  - math::aime_2025
  - coding::livecodebench_v5
78 changes: 75 additions & 3 deletions pipelinerl/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from multiprocessing.managers import SharedMemoryManager
from pathlib import Path
from queue import Empty
from typing import Dict, List
from typing import Dict, List, Optional

import aiohttp
import hydra
Expand Down Expand Up @@ -41,6 +41,7 @@
wait_for_environments,
wait_for_inference_servers,
)
from .curriculum import BanditConfig, CurriculumState

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -184,6 +185,17 @@ async def rollout_and_maybe_produce_result(
sample.metadata["model_version"] = model_version
sample.metadata["rollout_index"] = rollout_index
sample.metadata["step_index"] = step_index
# Propagate curriculum metadata if present
if "_selected_category" in problem:
sample.metadata["_selected_category"] = problem["_selected_category"]
if "id" in problem:
sample.metadata["id"] = problem["id"]
# Propagate all configured category fields for stats tracking
if cfg.get("curriculum") and cfg.curriculum.get("enabled"):
curriculum_config = BanditConfig(**cfg.curriculum)
for field in curriculum_config.get_all_category_fields():
if field in problem:
sample.metadata[f"_curriculum_{field}"] = problem[field]
sample.group_id = full_group_id
group_rollouts[group_id].append(rollout_result)
if len(group_rollouts[group_id]) == attempts:
Expand Down Expand Up @@ -308,6 +320,25 @@ def __init__(
self.is_scheduling_paused = False
self.debug_mode = bool(cfg.debug.mode)

# Initialize curriculum learning components if enabled (only for training)
self.curriculum_state: Optional[CurriculumState] = None
self._curriculum_config: Optional[BanditConfig] = None # Keep for validation
if is_training and cfg.get("curriculum") and cfg.curriculum.get("enabled", False):
exp_path = Path(cfg.output_dir)
self._curriculum_config = BanditConfig(**cfg.curriculum)

# Validate: "estimated" difficulty_source requires GRPO-like policy (attempts > 1)
# In estimated mode, per-group success rate is used as difficulty proxy
if self._curriculum_config.difficulty_source == "estimated" and cfg.attempts <= 1:
raise ValueError(
f"Curriculum difficulty_source='estimated' requires attempts > 1 (GRPO-like policy) "
f"to estimate difficulty from per-group success rate. Got attempts={cfg.attempts}"
)

# Feedback stream - listener will be started AFTER forking scheduler processes (fork safety)
self._curriculum_feedback_stream = SingleStreamSpec(exp_path=exp_path, topic="curriculum_feedback")
logger.info(f"Initialized curriculum learning with config: {self._curriculum_config}")

# Determine the number of processes to use
num_processes = min(self.cfg.actor.rollout_workers, len(self.llms))
attempts = self.cfg.attempts if is_training else 1
Expand Down Expand Up @@ -345,12 +376,29 @@ def __init__(
process.start()
self.rollout_processes.append(process)

# Start curriculum feedback listener AFTER forking scheduler processes (fork safety)
# Starting it before fork can cause multiprocessing.Queue corruption
if self._curriculum_config is not None and hasattr(self, '_curriculum_feedback_stream'):
self.curriculum_state = CurriculumState(
self._curriculum_config,
self._curriculum_feedback_stream,
)
self.curriculum_state.start_listening()
logger.info("Started curriculum feedback listener (after fork)")

def init_stats(self):
    """Reset every statistics accumulator on this actor.

    All attributes created here are filled in by ``update_stats`` and
    read out when stats are published.
    """

    def _nested_list_factory():
        # Factory for the inner two levels of the three-level stats mapping;
        # missing leaves default to an empty list of observations.
        return defaultdict(lambda: defaultdict(list))

    self.stats = defaultdict(_nested_list_factory)
    self.latency_list = []
    self.model_versions_list = []
    self.sliding_stats = defaultdict(list)
    self.domain_counts = defaultdict(int)

    # Curriculum batch-level tracking.
    self.curriculum_categories = []
    # Numeric values per feature field (for computing averages).
    self.curriculum_feature_values: Dict[str, List[float]] = defaultdict(list)
    # Categorical values per feature field (for distribution plots).
    self.curriculum_feature_categories: Dict[str, List[str]] = defaultdict(list)

def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]:
metrics = {}
Expand Down Expand Up @@ -398,7 +446,14 @@ def update_stats(self, rollout_results: List[RolloutResult]):
for k, v in sliding_window_stats.items():
self.sliding_stats[k].append(v)


# Track curriculum categories and feature values from rollout metadata
if self.curriculum_state is not None:
cats, feat_vals, feat_cats = self.curriculum_state.track_rollout_results(rollout_results)
self.curriculum_categories.extend(cats)
for field, values in feat_vals.items():
self.curriculum_feature_values[field].extend(values)
for field, values in feat_cats.items():
self.curriculum_feature_categories[field].extend(values)

def run(self, dataset: list[tuple[str, dict]]):
loop_start_time = time.time()
Expand All @@ -418,7 +473,12 @@ def run(self, dataset: list[tuple[str, dict]]):
# for test samples, loop through the dataset once
domain_sampler = None
if self.is_training:
problem_iter = random_iter(dataset)
if self.curriculum_state is not None:
problem_iter = self.curriculum_state.create_iterator(dataset)
logger.info("Using curriculum learning for sampling")
else:
problem_iter = random_iter(dataset)

domain_mix_cfg = getattr(self.cfg.actor, "domain_mix", None)
if domain_mix_cfg:
mix_weights = OmegaConf.to_container(domain_mix_cfg, resolve=True)
Expand Down Expand Up @@ -491,6 +551,8 @@ def run(self, dataset: list[tuple[str, dict]]):
else:
problem = next(problem_iter)
self.problem_queue.put(problem, block=False)
if "_selected_category" in problem:
logger.debug(f"Actor submitting problem with category: {problem['_selected_category']}")
submitted_groups += 1
except queue.Full:
assert False, "Problem queue was not full just a moment ago, but now it is full"
Expand Down Expand Up @@ -622,6 +684,16 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict):

for k, v in self.sliding_stats.items():
stats[k] = sum(v) / len(v) if v else 0

# Add curriculum stats if available
if self.curriculum_state is not None:
curriculum_stats = self.curriculum_state.compute_batch_stats(
self.curriculum_categories,
self.curriculum_feature_values,
self.curriculum_feature_categories,
)
stats |= curriculum_stats

if self.cfg.wandb.use_wandb:
wandb.log({f"actor/{k}": v for k, v in stats.items()})
stats_writer.write(stats)
Expand Down
18 changes: 18 additions & 0 deletions pipelinerl/curriculum/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Bandit-based curriculum learning for PipelineRL."""

from .bandit import BanditConfig, BanditState, BanditCurriculum, SuccessRateTracker
from .iterator import BanditIterator
from .feedback import CategoryFeedback, compute_category_feedback, CurriculumFeedbackTracker
from .state import CurriculumState

__all__ = [
"BanditConfig",
"BanditState",
"BanditCurriculum",
"SuccessRateTracker",
"BanditIterator",
"CategoryFeedback",
"compute_category_feedback",
"CurriculumFeedbackTracker",
"CurriculumState",
]
Loading