Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/prime_rl/orchestrator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,14 @@ class OrchestratorConfig(BaseSettings):
# The checkpoint configuration
ckpt: CheckpointConfig | None = None

# Whether to reset inference weights to base model when starting from scratch
reload_weights_on_start: Annotated[
bool,
Field(
description="Whether to reset inference weights to the base model when starting from scratch."
),
] = False

# The validation configuration
val: ValConfig | None = None

Expand Down
11 changes: 8 additions & 3 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,14 @@ async def orchestrate(config: OrchestratorConfig):
lora_name=config.model.lora.name if config.model.lora else None,
)
else:
logger.info("Training from scratch. Resetting weights to base model")
if config.model.lora is None:
await reload_weights(admin_clients)
if config.reload_weights_on_start:
if config.model.lora is None:
logger.info("Training from scratch. Resetting weights to base model")
await reload_weights(admin_clients)
else:
logger.info("Training from scratch. Skipping base weight reload because LoRA is enabled")
else:
logger.info("Training from scratch. Skipping base weight reload")

# Iterate over dataset in batches
max_steps = config.max_steps or int(1e9)
Expand Down