The MultiRunManager object is a global singleton that manages the parameters and components for multiple concurrent training runs within a single trainer process.
This allows multiple orchestrator deployments to share the same trainer.
When max_concurrent_runs > 1, the trainer can train multiple runs in parallel. Each run:
- Has its own LoRA adapter parameters
- Has its own optimizer and scheduler
- Saves its own checkpoints
- Tracks its own training progress (step, tokens, samples)
- Loads its own orchestrator configuration
The MultiRunManager object provides:
- Bidirectional mapping between run IDs (e.g.,
run_abc123) and run indices (0, 1, 2, ...) - Progress tracking per run (step count, total tokens, total samples)
- Configuration management for orchestrator configs
- Distributed synchronization across ranks via the PyTorch distributed store
- LoRA module registration for multi-adapter parameter management
- Creation hooks for initializing per-run resources (optimizers, schedulers)
- Run eviction for removing runs that are misbehaving
The MultiRunManager singleton is set up at the start of training:
from prime_rl.trainer.runs import setup_multi_run_manager, get_multi_run_manager
# Initialize with output directory and max concurrent runs
setup_multi_run_manager(output_dir=Path("outputs/my-experiment"), max_runs=4)
# Get the singleton instance anywhere in the codebase
multi_run_manager = get_multi_run_manager()Each run's directory follows this structure:
{output_dir}/
├── run_abc123/
│ ├── control/
│ │ ├── orch.toml # Orchestrator configuration
│ │ ├── config_validation_error.txt # Config validation errors (if any)
│ │ └── evicted.txt # Eviction reason (if evicted)
│ ├── checkpoints/
│ │ └── step_100/ # Orchestrator checkpoints
│ ├── rollouts/
│ │ └── step_100/ # Rollouts
│ └── broadcast/
│ └── step_100/ # Broadcasted weights for inference
├── run_def456/
│ └── ...
└── ...
Runs are discovered by scanning the output directory for the pattern run_*. Each run must contain a valid orchestrator config at {run_dir}/control/orch.toml before they are added to the active runs otherwise they are ignored. When the maximum number of runs is reached, new run_* directories will not be picked up until old ones are deleted.
# Master rank scans for new/deleted runs
multi_run_manager.discover_runs()
# All ranks synchronize state (must be called after discover_runs)
multi_run_manager.synchronize_state()The discover_runs() method (master only):
- Scans the output directory for
run_*directories - Filters out evicted runs (those with
control/evicted.txt) - Detects new runs and deleted runs
- Calls
forgotten_hookfor deleted runs (master only) - Loads and validates the orchestrator config for each new run
- Updates internal mappings and data structures
- Calls
discovered_hookfor new runs (master only)
The synchronize_state() method (all ranks):
- Master broadcasts run state to all ranks via the distributed store
- Non-master ranks catch up by calling internal
_delete_run_data/_create_run_data - All ranks execute
deletion_hookfor deleted runs - All ranks execute
creation_hookfor new runs (e.g., optimizer setup, LoRA parameter reset)
The master proc on the trainer can evict a run using the evict_run(idx: int, reason: str) method.
This is useful when the trainer detects an issue with a run that requires it to be stopped (e.g., invalid data, resource constraints, or policy violations).
# Evict a run by its index (master only)
multi_run_manager.evict_run(idx=0, reason="Run exceeded memory limits")The evict_run() method (master only):
- Writes the eviction reason to
{run_dir}/control/evicted.txt - Logs a warning with the eviction details
- The run is not immediately removed from the manager's data structures
The eviction takes effect through two mechanisms:
On the trainer side:
- The next
discover_runs()call will filter out the evicted run (it checks forevicted.txt) - The run will then be treated as deleted, triggering forgotten/deletion hooks
- The run index is returned to the unused pool
On the orchestrator side:
- The orchestrator checks for
evicted.txtat the start of each iteration in its main loop - If found, it raises a
RuntimeErrorwith the eviction reason, causing the orchestrator to exit - This surfaces the eviction reason to the user
LoRA modules register themselves with MultiRunManager for parameter management:
# In apply_lora_to_model()
lora_module = MultiLoRALinear(
base_layer=base_module,
rank=config.rank,
n_adapters=get_multi_run_manager().max_runs,
...
)
lora_module.register_with_runs(get_multi_run_manager(), module_name)The MultiRunManager object then exposes:
# Get parameters for a specific run (used by optimizer creation)
multi_run_manager.get_named_parameters_for_run(idx)
# Get state dict for a specific run (used by weight broadcast)
multi_run_manager.get_state_dict_for_run(idx)
# Reset parameters for a new run
multi_run_manager.reset_run_parameters(idx)The MultiRunManager object supports several types of hooks for different lifecycle events.
Deletion hooks are always called before creation hooks.
flowchart TD
subgraph master["Rank 0 (Master)"]
discover["discover_runs()"]
forgotten["forgotten_hooks"]
validation["config_validation_hooks"]
discovered["discovered_hooks"]
discover --> forgotten
forgotten --> validation
validation --> discovered
discovered --> discover
end
subgraph rank1["Rank 1"]
wait1["waiting..."]
end
subgraph rankN["Rank N"]
waitN["waiting..."]
end
discovered --> barrier
wait1 --> barrier
waitN --> barrier
barrier[["synchronize_state()"]]
barrier --> deletion["deletion_hooks"]
deletion --> creation["creation_hooks"]
style barrier fill:#fff9c4
# These hooks are only called on the master as only master uses `discover_runs()`
# These hooks are thus only relevant to master only components (packer)
multi_run_manager.register_discovered_hook(callback)
multi_run_manager.register_forgotten_hook(callback)
# These hooks are executed by all ranks in the order they were added during `synchronize_state()`
# This ensures DTensor creations and other distributed operations happen together
# Calling torch.dist.barrier() in a hook here should work
multi_run_manager.register_creation_hook(callback)
multi_run_manager.register_deletion_hook(callback)
# These hooks validate the orchestrator config when runs are discovered:
multi_run_manager.register_config_validation_hook(callback)The callback signatures:
def discovered_callback(idx: int, run_id: str, config: OrchestratorConfig) -> None:
"""Called when a new run is discovered (master only).
Args:
idx: The run's index (0 to max_runs-1)
run_id: The run's ID (e.g., "run_abc123")
config: The orchestrator config for the run
"""
# Example: Set the scaling factor for the run
multi_run_manager.scaling_factors[idx] = config.model.lora.alpha / config.model.lora.rank
def forgotten_callback(idx: int, run_id: str) -> None:
"""Called when a run is forgotten/removed (master only).
Args:
idx: The run's index (0 to max_runs-1)
run_id: The run's ID (e.g., "run_abc123")
"""
pass
def callback(idx: int, run_id: str) -> None:
"""Called when a run is created/deleted.
Args:
idx: The run's index (0 to max_runs-1)
run_id: The run's ID (e.g., "run_abc123")
"""
pass
def config_validation_callback(config: OrchestratorConfig) -> tuple[bool, str]:
"""Validate an orchestrator config.
Args:
config: The orchestrator config to validate
Returns:
(is_valid, error_message): If invalid, error_message is written to config dir
"""
return True, ""