
Commit bbbb3d6

fix: fix non-colocated with cpu_offload enabled (#861)
Signed-off-by: Yuki Huang <[email protected]>
1 parent 88a399e commit bbbb3d6

File tree

2 files changed: +102 -39 lines changed

  nemo_rl/models/policy/dtensor_policy_worker.py
  tests/unit/models/generation/test_vllm_generation.py

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 18 additions & 4 deletions
@@ -1224,8 +1224,11 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         """
         from nemo_rl.utils.nvml import get_free_memory_bytes

+        # Manually move model to cuda for cpu offload case
+        if self.cpu_offload:
+            self.model = self.move_to_cuda(self.model)
+
         # Get state_dict
-        self.model = self.move_to_cuda(self.model)
         self._held_sharded_state_dict_reference: dict[str, torch.Tensor] = (
             self.model.state_dict()
         )
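The hunk above makes the pre-refit onload conditional on cpu_offload before the sharded state_dict is captured for IPC export. As a rough standalone illustration of why the onload matters in the offload case (this is not the nemo_rl code, and reduce_tensor is just one way to produce CUDA IPC metadata): IPC handles can only be created for tensors that already live in GPU memory, so a CPU-offloaded parameter has to be moved onto the GPU first.

# Standalone illustration, not nemo_rl code: CUDA IPC export needs GPU-resident tensors.
import torch
from torch.multiprocessing.reductions import reduce_tensor

cpu_offload = True
weight = torch.randn(4, 4)  # stand-in parameter kept on CPU under cpu_offload

if cpu_offload and torch.cuda.is_available():
    weight = weight.cuda()  # onload, mirroring move_to_cuda in the hunk above

if weight.is_cuda:
    handle = reduce_tensor(weight)  # (rebuild_fn, args) another process can map via CUDA IPC
    print("exported IPC metadata:", type(handle))
else:
    print("No GPU available; CUDA IPC handle export requires a CUDA tensor.")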
@@ -1283,13 +1286,27 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
     @torch.no_grad()
     def broadcast_weights_for_collective(self) -> None:
         """Broadcast the weights for collective communication."""
+        # Manually move model to cuda for cpu offload case
+        if self.cpu_offload:
+            print(
+                "[WARNING]: Unless you are lacking of memory, it is not recommended to enable cpu_offload when "
+                "using non-colocated generation since it will have an extra onload and offload at refit stage."
+            )
+            self.model = self.move_to_cuda(self.model)
+
+        # Broadcast the weights for collective communication
         for _, tensor in self.model.state_dict().items():
             if isinstance(tensor, DTensor):
                 tensor = tensor.full_tensor()
             if self.rank == 0:
                 tensor = tensor.to(self.dtype, non_blocking=True)
                 self.model_update_group.broadcast(tensor.data, src=0)

+        # Manually move model to cpu for cpu offload case
+        # cpu offload needs model on CPU before model forward
+        if self.cpu_offload:
+            self.model = self.move_to_cpu(self.model)
+
     def prepare_for_lp_inference(self) -> None:
         if not self.cpu_offload:
             self.move_to_cuda(self.model)
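To see the refit-time behavior this hunk adds in isolation, here is a minimal self-contained sketch of the onload, broadcast, offload sequence under cpu_offload. It uses a single-process gloo group and a plain nn.Linear as stand-ins; the real worker broadcasts a sharded DTensor model over the model_update_group shown in the diff, so treat this only as an illustration of the pattern, not the nemo_rl implementation.

# Minimal standalone sketch of the onload -> broadcast -> offload pattern.
import os
import torch
import torch.distributed as dist
import torch.nn as nn

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)  # toy single-rank group

cpu_offload = True
model = nn.Linear(8, 8)  # stand-in for the policy model, resident on CPU

if cpu_offload:
    # onload before the broadcast (the addition in this commit)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")

for _, tensor in model.state_dict().items():
    # rank 0 is the source of truth for every other rank in the group
    dist.broadcast(tensor.data, src=0)

if cpu_offload:
    # offload again so the next forward under cpu_offload finds the model on CPU
    model = model.cpu()

dist.destroy_process_group()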
@@ -1308,9 +1325,6 @@ def prepare_for_training(self, *args, **kwargs) -> None:
         # to cuda automatically, so we need to do that manually
         self.model = self.move_buffer_to_device(self.model, "cuda")

-        # have to move buffers to cuda manually for cpu offload case
-        self.move_buffer_to_device(self.model, "cuda")
-
         self.model.train()
         # Move optimizer state to CUDA if it exists
         if (
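Taken together, the worker-side changes above touch both refit paths. The outline below is an assumed sketch (not the actual nemo_rl helper) of how a caller such as refit_policy_generation, used in the test changes that follow, picks between them:

# Assumed outline only: ties the worker methods changed above to the
# colocated/non-colocated split exercised by the tests below.
def refit_sketch(lm_policy, vllm_policy, colocated: bool) -> None:
    if colocated:
        # Same GPUs: export CUDA IPC handles (prepare_weights_for_ipc /
        # get_weights_ipc_handles) so the vLLM workers map the tensors in place.
        ...
    else:
        # Separate GPUs: stream weights over the collective group
        # (broadcast_weights_for_collective), which now onloads the model before
        # the broadcast and offloads it afterwards when cpu_offload is enabled.
        ...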

tests/unit/models/generation/test_vllm_generation.py

Lines changed: 84 additions & 35 deletions
@@ -623,29 +623,16 @@ def configure_worker_fixed_seed(num_gpus, bundle_indices=None):
     torch.cuda.empty_cache()


-@pytest.mark.timeout(360)
-@pytest.mark.asyncio
-@pytest.mark.parametrize("async_engine", [True, False])
-async def test_vllm_generation_with_hf_training(cluster, tokenizer, async_engine):
-    """1. Use vLLM for generation
-    2. Use HF policy for training and logprob computation
+async def run_hf_train_process(
+    lm_policy, vllm_policy, tokenizer, async_engine, colocated
+):
+    """Validates that the two policies can work together.

-    This test validates that the two policies can work together.
+    1. Use vLLM for generation
+    2. Use HF policy for training and logprob computation
     """
-    from nemo_rl.models.policy.lm_policy import Policy
     from tests.unit.test_utils import SimpleNLLLoss

-    # Create separate configs for each policy
-    vllm_config = deepcopy(basic_vllm_test_config)
-    vllm_config["vllm_cfg"]["async_engine"] = async_engine
-    vllm_config = configure_generation_config(vllm_config, tokenizer)
-
-    dtensor_config = deepcopy(basic_dtensor_test_config)
-    dtensor_config["train_global_batch_size"] = 4
-
-    vllm_policy = None
-    lm_policy = None
-
     try:
         prompts = [
             "Write a story about a magical forest",
@@ -677,22 +664,8 @@ async def test_vllm_generation_with_hf_training(cluster, tokenizer, async_engine
             }
         )

-        # Create both policies
-        print("Creating vLLM policy...")
-        vllm_policy = VllmGeneration(cluster, vllm_config)
-        vllm_policy.finish_generation()
-
-        print("Creating DTensor policy...")
-        lm_policy = Policy(cluster, dtensor_config, tokenizer)
-
-        print("preparing refit info...")
-        state_dict_info = lm_policy.prepare_refit_info()
-        vllm_policy.prepare_refit_info(state_dict_info)
-
         print("refitting vllm policy...")
-        refit_policy_generation(
-            lm_policy, vllm_policy, vllm_config["colocated"]["enabled"]
-        )
+        refit_policy_generation(lm_policy, vllm_policy, colocated)

         # Step 1: Use vLLM for generation
         print("Using vLLM policy for fast generation...")
@@ -794,7 +767,7 @@ async def test_vllm_generation_with_hf_training(cluster, tokenizer, async_engine
         print(f"Training loss: {results['loss']}")

         lm_policy.finish_training()
-        lm_policy.offload_after_refit()
+        refit_policy_generation(lm_policy, vllm_policy, colocated)

         # Step 4: Use vLLM for generation again to complete the workflow
         print("Using vLLM for generation again...")
@@ -821,6 +794,82 @@ async def test_vllm_generation_with_hf_training(cluster, tokenizer, async_engine
         lm_policy.shutdown()


+@pytest.mark.timeout(300)
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("async_engine", "cpu_offload"), [(True, False), (False, True)]
+)
+async def test_vllm_generation_with_hf_training_colocated(
+    cluster, tokenizer, async_engine, cpu_offload
+):
+    """This test validates that DTensor policy can work together with colocated vLLM policy."""
+    # Create VllmGeneration Policy
+    print("Creating vLLM policy...")
+    vllm_config = deepcopy(basic_vllm_test_config)
+    vllm_config["vllm_cfg"]["async_engine"] = async_engine
+    vllm_config = configure_generation_config(vllm_config, tokenizer)
+    vllm_policy = VllmGeneration(cluster, vllm_config)
+    vllm_policy.finish_generation()
+
+    # Create Policy
+    print("Creating DTensor policy...")
+    dtensor_config = deepcopy(basic_dtensor_test_config)
+    dtensor_config["dtensor_cfg"]["cpu_offload"] = cpu_offload
+    dtensor_config["train_global_batch_size"] = 4
+    lm_policy = Policy(cluster, dtensor_config, tokenizer)
+
+    # Prepare refit info
+    print("Preparing refit info...")
+    state_dict_info = lm_policy.prepare_refit_info()
+    vllm_policy.prepare_refit_info(state_dict_info)
+
+    # Test
+    await run_hf_train_process(lm_policy, vllm_policy, tokenizer, async_engine, True)
+
+
+@pytest.mark.timeout(300)
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("async_engine", "cpu_offload"), [(True, False), (False, True)]
+)
+async def test_vllm_generation_with_hf_training_non_colocated(
+    policy_cluster_separate, tokenizer, async_engine, cpu_offload
+):
+    """This test validates that DTensor policy can work together with non-colocated vLLM policy."""
+    generation_cluster_separate = get_generation_cluster_separate(1)
+
+    # Create VllmGeneration Policy
+    print("Creating vLLM policy...")
+    vllm_config = deepcopy(basic_vllm_test_config)
+    vllm_config["vllm_cfg"]["async_engine"] = async_engine
+    vllm_config["colocated"]["enabled"] = False
+    vllm_config = configure_generation_config(vllm_config, tokenizer)
+    vllm_policy = VllmGeneration(generation_cluster_separate, vllm_config)
+    vllm_policy.finish_generation()
+
+    # Create Policy
+    print("Creating DTensor policy...")
+    dtensor_config = deepcopy(basic_dtensor_test_config)
+    dtensor_config["generation"]["colocated"]["enabled"] = False
+    dtensor_config["dtensor_cfg"]["cpu_offload"] = cpu_offload
+    dtensor_config["train_global_batch_size"] = 4
+    lm_policy = Policy(policy_cluster_separate, dtensor_config, tokenizer)
+
+    # Refit
+    # initialize collective communication for update weights
+    ip, port = policy_cluster_separate.get_master_address_and_port()
+    futures_train = lm_policy.init_collective(ip, port, world_size=2)
+    futures_inference = vllm_policy.init_collective(ip, port, world_size=2)
+    ray.get(futures_train + futures_inference)
+
+    # prepare refit info
+    state_dict_info = lm_policy.prepare_refit_info()
+    vllm_policy.prepare_refit_info(state_dict_info)
+
+    # Test
+    await run_hf_train_process(lm_policy, vllm_policy, tokenizer, async_engine, False)
+
+
 def test_vllm_policy_tensor_parallel(cluster, tokenizer):
     """Test vLLM policy with tensor parallelism > 1."""
     # Configure with tensor_parallel_size=2
