 import os
 import time
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast
@@ -203,6 +204,9 @@ def setup(
     Returns:
         tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader
     """
+    # Start timing the entire setup process
+    setup_start_time = time.perf_counter()
+
     # Extract individual configs for easier access
     policy_config = master_config["policy"]
     generation_config = master_config["policy"]["generation"]
@@ -432,67 +436,128 @@ def setup(
     # ==========================
     print("\n▶ Setting up model and training...", flush=True)
 
-    # vllm model loading prefers clean environment, initialize policy_generation before policy (#52 will fix this)
+    # vllm model loading prefers clean environment, initialize policy_generation before policy in colocated mode
     backend = generation_config["backend"]
     generation_config["model_name"] = policy_config["model_name"]  # Needed for vLLM
 
+    # Dictionary to store worker initialization timing stats for logging
+    worker_init_timing_metrics = {}
+
+    # Prepare checkpoint paths
+    if last_checkpoint_path:
+        weights_path = Path(last_checkpoint_path) / "policy" / "weights"
+        optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer"
+    else:
+        weights_path = None
+        optimizer_path = None
+
+    if policy_config.get("megatron_cfg", {}).get("enabled", False):
+        ## NOTE: this is equal to the total number of scheduler steps
+        total_train_iters = min(
+            grpo_config["max_num_steps"],
+            grpo_config["max_num_epochs"] * len(dataloader),
+        )
+        policy_config["megatron_cfg"]["train_iters"] = total_train_iters
+
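+    # NOTE: the checkpoint-path and megatron train_iters blocks above were moved up
+    # from their old location (shown as removed lines further down) so that everything
+    # init_policy() needs is ready before it runs, possibly in parallel with init_vllm().
+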
+    # Define initialization functions that will be used in all paths
+    def init_policy():
+        """Initialize policy training workers."""
+        t0 = time.perf_counter()
+        p = Policy(
+            cluster=train_cluster,
+            config=policy_config,
+            tokenizer=tokenizer,
+            processor=processor,
+            weights_path=weights_path,
+            optimizer_path=optimizer_path,
+            init_optimizer=True,
+        )
+        return p, time.perf_counter() - t0
+
+    def init_vllm():
+        """Initialize vLLM generation workers."""
+        t0 = time.perf_counter()
+        pg = VllmGeneration(cluster=inference_cluster, config=generation_config)
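+        # Worker groups are created lazily on first use; finish_generation() is called
+        # right away to force them to initialize now, while GPU memory is still clean
+        # (see the comment on the removed lines below).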
+        pg.finish_generation()
+        return pg, time.perf_counter() - t0
+
+    # Handle backend-specific setup
     if backend == "megatron":
+        # Megatron backend: policy_generation is None, only initialize policy
         policy_generation = None
         print(
             f" ✓ Using {backend} backend for generation with {policy_config['model_name']}",
             flush=True,
         )
+
+        policy, policy_time = init_policy()
+        worker_init_timing_metrics["policy_init_time_s"] = policy_time
+
     elif backend == "vllm":
+        # vLLM backend: set up the config, then decide between parallel and sequential init
         generation_config = cast(VllmConfig, generation_config)
         if generation_config["vllm_cfg"]["precision"] == "fp8":
             assert loss_config["use_importance_sampling_correction"] is True, (
                 "Importance sampling must be enabled for vLLM FP8 generation for good convergence!"
             )
-        ## make vllm hf overrides match the training policy
         generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get(
             "hf_config_overrides", {}
         )
 
-        policy_generation = VllmGeneration(
-            cluster=inference_cluster, config=generation_config
-        )
-        # Worker groups are not initialized until the first call to run something on workergroups.
-        # vllm 0.8 fails in initialization if its called in the first training step since it has no clean view of the GPU memory (HF is sharing the same memory).
-        policy_generation.finish_generation()
+        # Determine if parallel initialization is possible (non-colocated mode)
+        use_parallel_init = not colocated_inference
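+        # In colocated mode the trainer and the vLLM engine share the same GPUs, so
+        # vLLM has to come up first while memory is clean; only non-colocated clusters
+        # can safely bring both worker groups up at the same time.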
+
+        if use_parallel_init:
+            # Parallel initialization: vLLM and Policy can initialize simultaneously
+            print(
+                " ⚡ Using parallel worker initialization (non-colocated mode)",
+                flush=True,
+            )
+
+            # Execute both initializations in parallel
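+            # A thread pool is sufficient here: both init functions spend most of their
+            # time blocked on remote (Ray) worker startup rather than on Python-level
+            # compute, so the two initializations overlap despite the GIL.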
+            parallel_start_time = time.perf_counter()
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                vllm_future = executor.submit(init_vllm)
+                policy_future = executor.submit(init_policy)
+                policy_generation, vllm_time = vllm_future.result()
+                policy, policy_time = policy_future.result()
+            parallel_wall_time = time.perf_counter() - parallel_start_time
+
+            # Store timing metrics
+            worker_init_timing_metrics["vllm_init_time_s"] = vllm_time
+            worker_init_timing_metrics["policy_init_time_s"] = policy_time
+            worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time
+            worker_init_timing_metrics["parallel_init_enabled"] = True
+
+        else:
+            # Sequential initialization: colocated mode requires vLLM to come up first while GPU memory is clean
+            print(
+                " ⚙️ Using sequential worker initialization (colocated mode)",
+                flush=True,
+            )
+
+            # Initialize vLLM first (clean GPU memory), then policy
+            policy_generation, vllm_time = init_vllm()
+            worker_init_timing_metrics["vllm_init_time_s"] = vllm_time
+
+            policy, policy_time = init_policy()
+            worker_init_timing_metrics["policy_init_time_s"] = policy_time
+            worker_init_timing_metrics["parallel_init_enabled"] = False
+
         print(
             f" ✓ Using vLLM backend for generation with {policy_config['model_name']}",
             flush=True,
         )
 
-    if last_checkpoint_path:
-        weights_path = Path(last_checkpoint_path) / "policy" / "weights"
-        optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer"
-    else:
-        weights_path = None
-        optimizer_path = None
-
-    if policy_config.get("megatron_cfg", {}).get("enabled", False):
-        ## NOTE: this is equal to the total number of scheduler steps
-        total_train_iters = min(
-            grpo_config["max_num_steps"],
-            grpo_config["max_num_epochs"] * len(dataloader),
-        )
-        policy_config["megatron_cfg"]["train_iters"] = total_train_iters
+    # Record when worker initialization completes (for calculating other setup time)
+    worker_init_complete_time = time.perf_counter() - setup_start_time
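+    # Everything that runs after this point (collective communication setup, refit
+    # info, the loss function) is attributed to "other_setup_time_s" in the summary below.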
 
-    policy = Policy(
-        cluster=train_cluster,
-        config=policy_config,
-        tokenizer=tokenizer,
-        processor=processor,
-        weights_path=weights_path,
-        optimizer_path=optimizer_path,
-        init_optimizer=True,
-    )
     # print the node IP and GPU ID of the policy workers for debugging
     policy.print_node_ip_and_gpu_id()
 
     # if it is not colocated inference, initialize collective communication for update weights
     if not colocated_inference:
+        t0 = time.perf_counter()
         ip, port = train_cluster.get_master_address_and_port()
         print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
         # world includes all training workers and all inference workers
@@ -508,15 +573,45 @@ def setup(
         )  # type: ignore
         # wait for all futures to complete
         ray.get(futures_train + futures_inference)
+        worker_init_timing_metrics["collective_init_time_s"] = time.perf_counter() - t0
 
     # prepare refit info
     state_dict_info = policy.prepare_refit_info()
     policy_generation.prepare_refit_info(state_dict_info)
 
     loss_fn = ClippedPGLossFn(loss_config)
 
+    # Calculate total setup time
+    total_setup_time = time.perf_counter() - setup_start_time
+    worker_init_timing_metrics["total_setup_time_s"] = total_setup_time
+
+    # Log worker initialization timing metrics to logger
+    if worker_init_timing_metrics:
+        print("\n▶ Worker Initialization Timing:")
+
+        vllm_time = worker_init_timing_metrics.get("vllm_init_time_s", 0)
+        policy_time = worker_init_timing_metrics.get("policy_init_time_s", 0)
+        total_setup = worker_init_timing_metrics.get("total_setup_time_s", 0)
+
+        if vllm_time:
+            print(f" vLLM init: {vllm_time:.1f}s")
+
+        if policy_time:
+            print(f" Policy init: {policy_time:.1f}s")
+
+        # Calculate "other" time (time after worker init completes)
+        other_time = total_setup - worker_init_complete_time
+        worker_init_timing_metrics["other_setup_time_s"] = other_time
+        print(f" Other setup: {other_time:.1f}s")
+
+        print(f" Total setup: {total_setup:.1f}s")
+
+        # Log all metrics to the logger for analysis
+        logger.log_metrics(worker_init_timing_metrics, step=0, prefix="timing/setup")
+
     print("\n" + "=" * 60)
     print(" " * 18 + "SETUP COMPLETE")
+    print(f" Total setup time: {total_setup_time:.1f}s")
     print("=" * 60 + "\n", flush=True)
 
     return (