
Commit b3a7892

feat: improve non-colocated startup by starting policy and vllm in parallel (#1515)
Signed-off-by: Terry Kong <terryk@nvidia.com>
1 parent 7124e44

File tree

2 files changed: +160 -35 lines

nemo_rl/algorithms/grpo.py

Lines changed: 126 additions & 31 deletions

@@ -15,6 +15,7 @@
 import os
 import time
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast

@@ -203,6 +204,9 @@ def setup(
     Returns:
         tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader
     """
+    # Start timing the entire setup process
+    setup_start_time = time.perf_counter()
+
     # Extract individual configs for easier access
     policy_config = master_config["policy"]
     generation_config = master_config["policy"]["generation"]
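
A note on the timing idiom this commit threads through setup(): snapshot
time.perf_counter() before a stage, store the delta in a metrics dict, and
emit everything once at the end. A minimal sketch of that idiom (not the
repository's code; expensive_stage is a hypothetical stand-in for the real
worker construction):

import time

timing_metrics: dict[str, float] = {}
setup_start = time.perf_counter()

def expensive_stage() -> None:
    time.sleep(0.1)  # placeholder for real work, e.g. spawning workers

t0 = time.perf_counter()
expensive_stage()
timing_metrics["policy_init_time_s"] = time.perf_counter() - t0  # per-stage delta

timing_metrics["total_setup_time_s"] = time.perf_counter() - setup_start
print(timing_metrics)  # the diff sends these to logger.log_metrics instead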

@@ -432,67 +436,128 @@ def setup(
     # ==========================
     print("\n▶ Setting up model and training...", flush=True)

-    # vllm model loading prefers clean environment, initialize policy_generation before policy (#52 will fix this)
+    # vllm model loading prefers clean environment, initialize policy_generation before policy in colocated mode
     backend = generation_config["backend"]
     generation_config["model_name"] = policy_config["model_name"]  # Needed for vLLM

+    # Dictionary to store worker initialization timing stats for logging
+    worker_init_timing_metrics = {}
+
+    # Prepare checkpoint paths
+    if last_checkpoint_path:
+        weights_path = Path(last_checkpoint_path) / "policy" / "weights"
+        optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer"
+    else:
+        weights_path = None
+        optimizer_path = None
+
+    if policy_config.get("megatron_cfg", {}).get("enabled", False):
+        ## NOTE: this is equal to the total number of scheduler steps
+        total_train_iters = min(
+            grpo_config["max_num_steps"],
+            grpo_config["max_num_epochs"] * len(dataloader),
+        )
+        policy_config["megatron_cfg"]["train_iters"] = total_train_iters
+
+    # Define initialization functions that will be used in all paths
+    def init_policy():
+        """Initialize policy training workers."""
+        t0 = time.perf_counter()
+        p = Policy(
+            cluster=train_cluster,
+            config=policy_config,
+            tokenizer=tokenizer,
+            processor=processor,
+            weights_path=weights_path,
+            optimizer_path=optimizer_path,
+            init_optimizer=True,
+        )
+        return p, time.perf_counter() - t0
+
+    def init_vllm():
+        """Initialize vLLM generation workers."""
+        t0 = time.perf_counter()
+        pg = VllmGeneration(cluster=inference_cluster, config=generation_config)
+        pg.finish_generation()
+        return pg, time.perf_counter() - t0
+
+    # Handle backend-specific setup
     if backend == "megatron":
+        # Megatron backend: policy_generation is None, only initialize policy
         policy_generation = None
         print(
             f" ✓ Using {backend} backend for generation with {policy_config['model_name']}",
             flush=True,
         )
+
+        policy, policy_time = init_policy()
+        worker_init_timing_metrics["policy_init_time_s"] = policy_time
+
     elif backend == "vllm":
+        # vLLM backend: setup config, then decide parallel vs sequential init
         generation_config = cast(VllmConfig, generation_config)
         if generation_config["vllm_cfg"]["precision"] == "fp8":
             assert loss_config["use_importance_sampling_correction"] is True, (
                 "Importance sampling must be enabled for vLLM FP8 generation for good convergence!"
             )
-        ## make vllm hf overrides match the training policy
         generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get(
             "hf_config_overrides", {}
         )

-        policy_generation = VllmGeneration(
-            cluster=inference_cluster, config=generation_config
-        )
-        # Worker groups are not initialized until the first call to run something on workergroups.
-        # vllm 0.8 fails in initialization if its called in the first training step since it has no clean view of the GPU memory (HF is sharing the same memory).
-        policy_generation.finish_generation()
+        # Determine if parallel initialization is possible (non-colocated mode)
+        use_parallel_init = not colocated_inference
+
+        if use_parallel_init:
+            # Parallel initialization: vLLM and Policy can initialize simultaneously
+            print(
+                " ⚡ Using parallel worker initialization (non-colocated mode)",
+                flush=True,
+            )
+
+            # Execute both initializations in parallel
+            parallel_start_time = time.perf_counter()
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                vllm_future = executor.submit(init_vllm)
+                policy_future = executor.submit(init_policy)
+                policy_generation, vllm_time = vllm_future.result()
+                policy, policy_time = policy_future.result()
+            parallel_wall_time = time.perf_counter() - parallel_start_time
+
+            # Store timing metrics
+            worker_init_timing_metrics["vllm_init_time_s"] = vllm_time
+            worker_init_timing_metrics["policy_init_time_s"] = policy_time
+            worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time
+            worker_init_timing_metrics["parallel_init_enabled"] = True
+
+        else:
+            # Sequential initialization: colocated mode (GPU memory requires vLLM first)
+            print(
+                " ⚙️ Using sequential worker initialization (colocated mode)",
+                flush=True,
+            )
+
+            # Initialize vLLM first (clean GPU memory), then policy
+            policy_generation, vllm_time = init_vllm()
+            worker_init_timing_metrics["vllm_init_time_s"] = vllm_time
+
+            policy, policy_time = init_policy()
+            worker_init_timing_metrics["policy_init_time_s"] = policy_time
+            worker_init_timing_metrics["parallel_init_enabled"] = 0.0
+
         print(
             f" ✓ Using vLLM backend for generation with {policy_config['model_name']}",
             flush=True,
         )

-    if last_checkpoint_path:
-        weights_path = Path(last_checkpoint_path) / "policy" / "weights"
-        optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer"
-    else:
-        weights_path = None
-        optimizer_path = None
-
-    if policy_config.get("megatron_cfg", {}).get("enabled", False):
-        ## NOTE: this is equal to the total number of scheduler steps
-        total_train_iters = min(
-            grpo_config["max_num_steps"],
-            grpo_config["max_num_epochs"] * len(dataloader),
-        )
-        policy_config["megatron_cfg"]["train_iters"] = total_train_iters
+    # Record when worker initialization completes (for calculating other setup time)
+    worker_init_complete_time = time.perf_counter() - setup_start_time

-    policy = Policy(
-        cluster=train_cluster,
-        config=policy_config,
-        tokenizer=tokenizer,
-        processor=processor,
-        weights_path=weights_path,
-        optimizer_path=optimizer_path,
-        init_optimizer=True,
-    )
     # print the node IP and GPU ID of the policy workers for debugging
     policy.print_node_ip_and_gpu_id()

     # if it is not colocated inference, initialize collective communication for update weights
     if not colocated_inference:
+        t0 = time.perf_counter()
         ip, port = train_cluster.get_master_address_and_port()
         print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
         # world includes all training workers and all inference workers

@@ -508,15 +573,45 @@ def setup(
         )  # type: ignore
         # wait for all futures to complete
         ray.get(futures_train + futures_inference)
+        worker_init_timing_metrics["collective_init_time_s"] = time.perf_counter() - t0

     # prepare refit info
     state_dict_info = policy.prepare_refit_info()
     policy_generation.prepare_refit_info(state_dict_info)

     loss_fn = ClippedPGLossFn(loss_config)

+    # Calculate total setup time
+    total_setup_time = time.perf_counter() - setup_start_time
+    worker_init_timing_metrics["total_setup_time_s"] = total_setup_time
+
+    # Log worker initialization timing metrics to logger
+    if worker_init_timing_metrics:
+        print("\n▶ Worker Initialization Timing:")
+
+        vllm_time = worker_init_timing_metrics.get("vllm_init_time_s", 0)
+        policy_time = worker_init_timing_metrics.get("policy_init_time_s", 0)
+        total_setup = worker_init_timing_metrics.get("total_setup_time_s", 0)
+
+        if vllm_time:
+            print(f"  vLLM init: {vllm_time:.1f}s")
+
+        if policy_time:
+            print(f"  Policy init: {policy_time:.1f}s")
+
+        # Calculate "other" time (time after worker init completes)
+        other_time = total_setup - worker_init_complete_time
+        worker_init_timing_metrics["other_setup_time_s"] = other_time
+        print(f"  Other setup: {other_time:.1f}s")
+
+        print(f"  Total setup: {total_setup:.1f}s")
+
+        # Log all metrics to the logger for analysis
+        logger.log_metrics(worker_init_timing_metrics, step=0, prefix="timing/setup")
+
     print("\n" + "=" * 60)
     print(" " * 18 + "SETUP COMPLETE")
+    print(f"  Total setup time: {total_setup_time:.1f}s")
     print("=" * 60 + "\n", flush=True)

     return (
nemo_rl/distributed/worker_groups.py

Lines changed: 34 additions & 4 deletions

@@ -13,13 +13,15 @@
 # limitations under the License.
 import importlib
 import os
+import time
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, Optional, Union

 import ray
 from ray.util.placement_group import PlacementGroup
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from tqdm import tqdm

 from nemo_rl.distributed.named_sharding import NamedSharding
 from nemo_rl.distributed.ray_actor_environment_registry import (

@@ -383,7 +385,8 @@ def __init__(
         ):
             if worker_count > pg.bundle_count:
                 raise ValueError(
-                    f"Placement group {i} has {pg.bundle_count} bundles, but {worker_count} workers were requested"
+                    f"Placement group {i} has {pg.bundle_count} bundles, "
+                    f"but {worker_count} workers were requested"
                 )

         for bundle_idx in range(worker_count):

@@ -561,12 +564,39 @@ def _create_workers_from_bundle_indices(

                 global_rank += 1

+        # Wait for all workers to initialize with timing and progress bar
+        num_workers = len(worker_futures)
+        worker_refs = [future for future, _ in worker_futures]
+
+        start_time = time.perf_counter()
+
+        # Use ray.wait() to track individual worker completion times
+        remaining_refs = worker_refs.copy()
+
+        with tqdm(
+            total=num_workers,
+            desc=f"Initializing {self.name_prefix} workers",
+            unit="worker",
+            disable=False,
+        ) as pbar:
+            while remaining_refs:
+                # Wait for at least one worker to complete
+                ready_refs, remaining_refs = ray.wait(
+                    remaining_refs, num_returns=1, timeout=None
+                )
+
+                # Update progress bar for each ready worker
+                for _ in ready_refs:
+                    pbar.update(1)
+
+        # Get all worker results
+        workers = ray.get(worker_refs)
+        total_init_time = time.perf_counter() - start_time
+
         print(
-            f"Waiting for {len(worker_futures)} workers to finish initializing...",
+            f"{num_workers} workers initialized in {total_init_time:.2f}s",
             flush=True,
         )
-        worker_refs = [future for future, _ in worker_futures]
-        workers = ray.get(worker_refs)

         for idx, (worker, (_, initializer)) in enumerate(zip(workers, worker_futures)):
             worker._RAY_INITIALIZER_ACTOR_REF_TO_AVOID_GC = initializer
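
The progress-reporting pattern added above generalizes to any batch of Ray
object refs. A minimal runnable sketch (assuming ray and tqdm are installed;
start_worker is a hypothetical task standing in for actor construction):

import random
import time

import ray
from tqdm import tqdm

@ray.remote
def start_worker(rank: int) -> int:
    time.sleep(random.uniform(0.1, 0.5))  # stand-in for slow worker init
    return rank

if __name__ == "__main__":
    ray.init()
    refs = [start_worker.remote(i) for i in range(8)]

    t0 = time.perf_counter()
    remaining = list(refs)
    with tqdm(total=len(refs), desc="Initializing workers", unit="worker") as pbar:
        while remaining:
            # Block until at least one task finishes, then tick the bar.
            ready, remaining = ray.wait(remaining, num_returns=1)
            pbar.update(len(ready))

    workers = ray.get(refs)  # already complete; gathers results in order
    print(f"{len(workers)} workers initialized in {time.perf_counter() - t0:.2f}s")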
