diff --git a/examples/llm-api/rl_integration_test_async.py b/examples/llm-api/rl_integration_test_async.py new file mode 100644 index 00000000000..e2608e8e02c --- /dev/null +++ b/examples/llm-api/rl_integration_test_async.py @@ -0,0 +1,647 @@ +import argparse +import torch +import pynvml +import contextlib +import torch.distributed as dist +import atexit +import os +import asyncio +from typing import Any, Optional, Generator + +from tensorrt_llm import SamplingParams +from tensorrt_llm import AsyncLLM +from tensorrt_llm.llmapi.llm_args import KvCacheConfig +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + StateDictType, + MixedPrecision, + ShardedStateDictConfig, + FullStateDictConfig +) +#from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType +from transformers import AutoModelForCausalLM, AutoTokenizer +from torch.distributed.tensor import DTensor +import torch.multiprocessing as mp +from tensorrt_llm._utils import get_free_port + +def init_distributed(): + """Initialize distributed training""" + if "LOCAL_RANK" not in os.environ: + return 1, 0, torch.device("cuda:0") + + # Set default environment variables if not already set + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "29500" + + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + world_size = dist.get_world_size() + rank = dist.get_rank() + torch.cuda.set_device(rank) + + return world_size, rank, torch.device(f"cuda:{rank}") + +def exit_distributed(): + """Exit distributed training""" + if dist.is_initialized(): + dist.destroy_process_group() + +def report_device_id() -> str: + """Report the UUID of the current CUDA device using NVML. + Returns: + str: UUID of the device in the format "GPU-xxxxx" + """ + from tensorrt_llm._torch.utils import get_device_uuid + # Get current device index from torch + device_idx = torch.cuda.current_device() + # Get device UUID using NVML + uuid = get_device_uuid(device_idx) + print(f"fsdp: id: {device_idx}, uuid: {uuid}") + return uuid + +@contextlib.contextmanager +def nvml_context() -> Generator[None, None, None]: + """Context manager for NVML initialization and shutdown. + + Raises: + RuntimeError: If NVML initialization fails + """ + try: + pynvml.nvmlInit() + yield + except pynvml.NVMLError as e: + raise RuntimeError(f"Failed to initialize NVML: {e}") + finally: + try: + pynvml.nvmlShutdown() + except: + pass + +def device_id_to_physical_device_id(device_id: int) -> int: + """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES.""" + import os + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + try: + physical_device_id = int(device_ids[device_id]) + return physical_device_id + except ValueError: + raise RuntimeError( + f"Failed to convert logical device ID {device_id} to physical device ID. Available devices are: {device_ids}." 
+ ) + else: + return device_id + +def get_free_memory_bytes(device_idx: int) -> float: + """Get the free memory of a CUDA device in bytes using NVML.""" + global_device_idx = device_id_to_physical_device_id(device_idx) + with nvml_context(): + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx) + return pynvml.nvmlDeviceGetMemoryInfo(handle).free + except pynvml.NVMLError as e: + raise RuntimeError( + f"Failed to get free memory for device {device_idx} (global index: {global_device_idx}): {e}" + ) + +class fsdp_interface: + def __init__(self, model_dir): + self.model_dir = model_dir + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + self.device = torch.device(f"cuda:{self.rank}") + self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + self.model = self.load_fsdp_model(model_dir) + + def load_fsdp_model(self, model_dir): + """Load and initialize FSDP model""" + # Initialize distributed setup + print(f"World size: {self.world_size}, Rank: {self.rank}, Device: {self.device}") + + # Setup mixed precision policy for FSDP + mixed_precision_policy = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32 + ) + + if self.rank == 0: + print(f"Loading FSDP model from {model_dir}") + + # Initialize FSDP model + fsdp_model = AutoModelForCausalLM.from_pretrained( + model_dir, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + device_map=self.device + ) + + # Print model info + if self.rank == 0: + total_params = sum(p.numel() for p in fsdp_model.parameters()) + trainable_params = sum(p.numel() for p in fsdp_model.parameters() if p.requires_grad) + print(f"Total parameters: {total_params:,}") + print(f"Trainable parameters: {trainable_params:,}") + print(f"Model device: {next(fsdp_model.parameters()).device}") + + # Wrap model with FSDP + fsdp_model = FSDP( + fsdp_model, + mixed_precision=mixed_precision_policy, + device_id=torch.cuda.current_device(), + use_orig_params=True + ) + + if self.rank == 0: + print("FSDP model initialized successfully") + + self._held_streamed_param_reference = None + self._held_sharded_state_dict_reference = None + + return fsdp_model + + + + def per_tensor_generator(self): + # If the model is not FSDP, then we need to manually move it to the GPU + # For an FSDP model, model.state_dict() will move the params to the GPU + if not isinstance(self.model, FSDP): + self.model = self.manual_load_to_gpu(self.model) + self._held_sharded_state_dict_reference = self.model.state_dict() + else: + # Get sharded state dict instead of full state dict for FSDP1 + with FSDP.state_dict_type( + self.model, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig() + ): + self._held_sharded_state_dict_reference = self.model.state_dict() + for name, param in self._held_sharded_state_dict_reference.items(): + yield name, param + + @torch.no_grad() + def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: + # If the model is not FSDP, then we need to manually move it to the GPU + # For an FSDP model, model.state_dict() will move the params to the GPU + if not isinstance(self.model, FSDP): + self.model = self.manual_load_to_gpu(self.model) + self._held_sharded_state_dict_reference = self.model.state_dict() + else: + # Get sharded state dict instead of full state dict for FSDP1 + with FSDP.state_dict_type( + self.model, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig() + ): + 
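+                # FULL_STATE_DICT materializes each parameter as a full (unsharded) tensor on
+                # every rank, so downstream consumers receive plain tensors rather than DTensors.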
self._held_sharded_state_dict_reference = self.model.state_dict() + + # Collect info for streaming multiple tensors + ### state_dict_info = [] + ### for name, tensor in self._held_sharded_state_dict_reference.items(): + ### # dtensor's numel will return complete tensor instead of only local tensor + ### size_in_bytes = tensor.element_size() * tensor.numel() + ### state_dict_info.append((name, size_in_bytes)) + self.refit_param_info = [] + for name, tensor in self._held_sharded_state_dict_reference.items(): + # dtensor's numel will return complete tensor instead of only local tensor + size_in_bytes = tensor.element_size() * tensor.numel() + self.refit_param_info.append((name, size_in_bytes)) + + #print(f"State dict info: {state_dict_info}") + # Collect current available memory for refit + ## Get current device index from torch + device_idx = torch.cuda.current_device() + ## Get device free memory using NVML + total_available_bytes = get_free_memory_bytes(device_idx) + ## Use 80% of the free memory for safety + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8") + total_available_bytes *= float(memory_ratio) + + return self.refit_param_info, total_available_bytes + + @torch.no_grad() + def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: + from torch.distributed.tensor import DTensor + from torch.multiprocessing.reductions import reduce_tensor + + assert self._held_sharded_state_dict_reference is not None, ( + "prepare_weights_for_ipc must be called before get_weights_ipc_handles" + ) + + # Clean up the held tensors to reduce peak memory + if self._held_streamed_param_reference is not None: + del self._held_streamed_param_reference + self._held_streamed_param_reference = None + + converted_params = {} + for key in keys: + # Get full_tensor for dtensor (GPU > 1) + if not key.startswith("model."): + continue + tensor = self._held_sharded_state_dict_reference[key] + if isinstance(tensor, DTensor): + full_tensor = tensor.full_tensor() + else: + full_tensor = tensor + # Convert parameters to the configured dtype + #print(f"FSDP rank {self.rank} name: {key}, shape: {full_tensor.shape}, {full_tensor[0]}") + converted_params[key] = full_tensor + + # Temporary record the full tensor for cleanup + # It is needed for cleanup the last full_tensor in the refit process + self._held_streamed_param_reference = converted_params + + # Get device UUID for IPC + device_uuid = report_device_id() + # Create handles for the tensors + all_handles = [] + for key, p in converted_params.items(): + handle = reduce_tensor(p.detach()) + all_handles.append((key, handle)) + + #print(f"device_uuid: {device_uuid}, All handles keys: {[key for key, _ in all_handles]}") + print(f"device_uuid: {device_uuid}") + return {device_uuid: all_handles} + + @torch.no_grad() + def prepare_weights_for_ipc_refit( + self, _refit_buffer_size_gb: Optional[int] = None + ) -> list[list[str]]: + """Prepare the weights for IPC. + + Returns: + list: A list containing the keys of the parameters, which is grouped by size. 
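+
+        Example (illustrative sizes, assuming _refit_buffer_size_gb=1): for
+        refit_param_info [("a", 0.6 GiB), ("b", 0.6 GiB), ("c", 0.2 GiB)] the
+        grouping is [["a"], ["b", "c"]], since "a" alone exhausts the first
+        1 GiB chunk while "b" and "c" fit together in the next one.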
+ """ + # Get the state_dict_info and available memory from all workers + state_dict_info = self.refit_param_info + + if _refit_buffer_size_gb is not None: + total_available_bytes = _refit_buffer_size_gb * (1024**3) + else: + # Get the minimum available memory from all workers + total_available_bytes = min(result[1] for result in state_dict_info) + + # Group tensors by size + cur_available_bytes = total_available_bytes + grouped_param_keys: list[list[str]] = [] + keys: list[str] = [] + + for key, size_in_bytes in state_dict_info: + if size_in_bytes > cur_available_bytes: + if keys: + grouped_param_keys.append(keys) + keys = [] + cur_available_bytes = total_available_bytes + + keys.append(key) + cur_available_bytes -= size_in_bytes + + if keys: + grouped_param_keys.append(keys) + + return grouped_param_keys + +class NamedParam: + def __init__(self, name, size, param): + self.name = name + self.size = size + self.param = param + +class GateAndUp: + def __init__(self): + self.gate = None + self.up = None + def set_gate(self, gate): + self.gate = gate + def set_up(self, up): + self.up = up + def get_size(self): + return self.gate.size + self.up.size + def is_complete(self): + return self.gate is not None and self.up is not None + +class trtllm_interface: + def __init__(self, model_dir, tensor_parallel_size): + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + self.device = torch.device(f"cuda:{self.rank}") + self.model_dir = model_dir + self.tensor_parallel_size = tensor_parallel_size + + async def init_trtllm(self): + self.llm = await self.load_trtllm_model(self.model_dir, self.tensor_parallel_size) + + async def load_trtllm_model(self, model_dir, tensor_parallel_size): + if self.rank == 0: + print(f"Loading TensorRT-LLM model: {model_dir}, tensor_parallel_size: {tensor_parallel_size}") + # Save and clear distributed environment variables to avoid conflicts + # Ray orchestrator will set up its own process group in separate actors + saved_env = {} + env_vars_to_clear = ['LOCAL_RANK', 'RANK', 'WORLD_SIZE', 'LOCAL_WORLD_SIZE'] + for var in env_vars_to_clear: + if var in os.environ: + saved_env[var] = os.environ[var] + del os.environ[var] + + try: + llm = AsyncLLM( + model=model_dir, + tensor_parallel_size=tensor_parallel_size, + orchestrator_type='ray', + ray_worker_extension_cls='tensorrt_llm.llmapi.rlhf_utils.WorkerExtension', + load_format='dummy', + #enable_sleep=True, # crash + kv_cache_config=KvCacheConfig( + free_gpu_memory_fraction=0.85, + enable_block_reuse=False + ) + ) + await llm.async_init_phase() + finally: + # Restore environment variables + for var, value in saved_env.items(): + os.environ[var] = value + + return llm + else: + return None + + def update_weights_from_ipc_handles(self, rank, device_handles): + if rank == 0: + gathered_handles = [None for _ in range(dist.get_world_size())] + else: + gathered_handles = None + dist.gather_object( + obj=device_handles, + object_gather_list=gathered_handles, + dst=0 + ) + if rank == 0: + all_handles = {k: v for d in gathered_handles for k, v in d.items()} + result = self.llm._collective_rpc('update_weights', (all_handles, )) + return result + else: + return None + + def update_weights_from_tensor_generator(self, tensor_generator): + device_uuid = report_device_id() + rank = dist.get_rank() + from torch.multiprocessing.reductions import reduce_tensor + total_available_bytes = 0.7 * (1024**3) + cur_available_bytes = total_available_bytes + converted_params = {} + cur_handles = [] + gate_up = {} + stream_step = 0 + 
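+        # Streaming refit (sketch of the intent, with a fixed ~0.7 GiB staging budget):
+        #   1. gate_proj / up_proj tensors are buffered in `gate_up` until both halves of a
+        #      layer arrive, presumably so the fused gate/up weight is refit in one step.
+        #   2. Tensors are turned into CUDA IPC handles via reduce_tensor() and batched
+        #      until the next tensor would overflow the budget.
+        #   3. Each flush gathers the per-rank handles to rank 0, which issues the
+        #      'update_weights' collective RPC to the TRT-LLM Ray workers.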
for name, param in tensor_generator: + size_in_bytes = param.element_size() * param.numel() + if isinstance(param, DTensor): + param = param.full_tensor() + gate_up_name = None + gate_up_pair = None + if "gate_proj" in name: + gate_up_name = name.replace("gate_proj", "") + if (gate_up_name not in gate_up): + gate_up[gate_up_name] = GateAndUp() + assert gate_up[gate_up_name].gate is None + gate_up[gate_up_name].set_gate(NamedParam(name, size_in_bytes, param)) + elif "up_proj" in name: + gate_up_name = name.replace("up_proj", "") + if (gate_up_name not in gate_up): + gate_up[gate_up_name] = GateAndUp() + assert gate_up[gate_up_name].up is None + gate_up[gate_up_name].set_up(NamedParam(name, size_in_bytes, param)) + if (gate_up_name is not None): + if gate_up[gate_up_name].is_complete(): + gate_up_pair = gate_up.pop(gate_up_name) + size_in_bytes = gate_up_pair.get_size() + else: + continue + + if size_in_bytes > cur_available_bytes: + stream_step += 1 + device_handles = {device_uuid: cur_handles} + print(f"stream_step: {stream_step}") + result = self.update_weights_from_ipc_handles(rank, device_handles) + print(f"update_weights_from_ipc_handles result: {result}") + cur_available_bytes = total_available_bytes + del converted_params + converted_params = {} + cur_handles = [] + + assert cur_available_bytes >= size_in_bytes + cur_available_bytes -= size_in_bytes + if (gate_up_pair is not None): + converted_params[gate_up_pair.gate.name] = gate_up_pair.gate.param + converted_params[gate_up_pair.up.name] = gate_up_pair.up.param + handle = reduce_tensor(gate_up_pair.gate.param.detach()) + cur_handles.append((gate_up_pair.gate.name, handle)) + handle = reduce_tensor(gate_up_pair.up.param.detach()) + cur_handles.append((gate_up_pair.up.name, handle)) + gate_up_pair = None + else: + converted_params[name] = param + handle = reduce_tensor(param.detach()) + cur_handles.append((name, handle)) + + assert len(gate_up) == 0 + + if cur_handles: + device_handles = {device_uuid: cur_handles} + stream_step += 1 + print(f"stream_step: {stream_step}") + result = self.update_weights_from_ipc_handles(rank, device_handles) + print(f"update_weights_from_ipc_handles result: {result}") + cur_available_bytes = total_available_bytes + del converted_params + converted_params = {} + cur_handles = [] + +def get_current_process_memory_info() -> int: + """ + Returns GPU memory usage for current process in bytes. 
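+
+    Assumes the current process has a CUDA context on GPU 0; returns 0 if the
+    process is not listed among that device's compute processes.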
+ """ + # Get current process ID + current_pid = os.getpid() + # Get device handle for GPU 0 + device_handle = pynvml.nvmlDeviceGetHandleByIndex(0) + + # Get running processes + processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle) + + # Find current process + for process in processes: + if process.pid == current_pid: + return process.usedGpuMemory + + return 0 + +def get_current_mem_info(message: str = ""): + import nvsmi + mem_allocated = torch.cuda.memory_allocated() + mem_reserved = torch.cuda.memory_reserved() + mem_free, mem_total = torch.cuda.mem_get_info() + process_mem_info = get_current_process_memory_info() + print(f"{message} mem_free: {mem_free:,}, mem_total: {mem_total:,}, mem_allocated: {mem_allocated:,}, mem_reserved: {mem_reserved:,}, process_mem_info: {process_mem_info:,}") + for gpu in nvsmi.get_gpus(): + print(gpu) + return mem_free, mem_total, mem_allocated, mem_reserved, process_mem_info + +def get_total_available_bytes(pg: dist.ProcessGroup, message: str = "") -> int: + mem_allocated = torch.cuda.memory_allocated() + mem_reserved = torch.cuda.memory_reserved() + mem_free, mem_total = torch.cuda.mem_get_info() + print(f"{message} mem_free: {mem_free:,}, mem_total: {mem_total:,}, mem_allocated: {mem_allocated:,}, mem_reserved: {mem_reserved:,}") + mem_free = torch.tensor(mem_free) + dist.all_reduce(mem_free, op=dist.ReduceOp.MIN, group=pg) + mem_free = mem_free.item() + print(f"{message} gathered_mem_free: {mem_free:,}") + return mem_free * 0.2 + +def cleanup(): + """Cleanup function to destroy process group""" + if dist.is_initialized(): + print(f"Cleaning up process group on rank {dist.get_rank()}") + dist.destroy_process_group() + +async def async_worker(rank, world_size, model_dir, tensor_parallel_size, use_fsdp): + #os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" + #os.environ["TRTLLM_RAY_BUNDLE_INDICES"] = "1,2,3,4,5,6,7" + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" + os.environ["TRTLLM_RAY_BUNDLE_INDICES"] = "1,2" + #os.environ["TRTLLM_RAY_PER_WORKER_GPUS"] = "1" + + """Async worker function that runs the actual test logic within an event loop.""" + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + tags = ["sampler", + "drafter", + "guided_decoder", + "spec_resource_manager", + "_no_capture_model_extra", + "executor_extra", + "kv_cache", + "model", + "draft_model"] + + world_size, rank, device = init_distributed() + + sampling_params = SamplingParams(max_tokens=32) + + # Load FSDP model + fsdp = fsdp_interface(model_dir) + trtllm = trtllm_interface(model_dir, tensor_parallel_size) + await trtllm.init_trtllm() + + if rank == 0: + print(f"Collected handles from all {world_size} ranks:") + + # For FSDP mode, we would need additional logic to integrate withTensorRT-LLM + # This is a placeholder for now + if rank == 0: + for prompt in prompts: + outputs = await trtllm.llm.generate_async(prompt, sampling_params) + generated_text = outputs.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + ## load the model from fsdp + ## then generate the output again + ## get_current_mem_info("Before sleep") + ## result = trtllm.llm._collective_rpc('sleep', args=(tags,)) + ## print(f"sleep result: {result}") + ## get_current_mem_info("After sleep") +## + ## result = trtllm.llm._collective_rpc('wakeup', args=(tags,)) + ## print(f"wakeup result: {result}") + ## get_current_mem_info("After wakeup") + + 
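+    # Stream the FSDP weights into the TRT-LLM Ray workers over CUDA IPC (all ranks
+    # participate in the gather; rank 0 issues the update RPC), then generate again
+    # below to verify the refit.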
trtllm.update_weights_from_tensor_generator(fsdp.per_tensor_generator()) + + # generate the output again + if rank == 0: + for prompt in prompts: + outputs = await trtllm.llm.generate_async(prompt, sampling_params) + generated_text = outputs.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + ## load the model from fsdp + ## then generate the output again + ## get_current_mem_info("Before sleep") + ## result = trtllm.llm._collective_rpc('sleep', args=(tags,)) + ## print(f"sleep result: {result}") + ## get_current_mem_info("After sleep") +## + ## result = trtllm.llm._collective_rpc('wakeup', args=(tags,)) + ## print(f"wakeup result: {result}") + ## get_current_mem_info("After wakeup") + + + ##trtllm.update_weights_from_tensor_generator(fsdp.per_tensor_generator()) +## + ### generate the output again + ##if rank == 0: + ## outputs = trtllm.llm.generate(prompts, sampling_params) + ## for i, output in enumerate(outputs): + ## prompt = output.prompt + ## generated_text = output.outputs[0].text + ## print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}") +## + exit_distributed() + +def worker(rank, world_size, master_port, model_dir, tensor_parallel_size, use_fsdp): + """Worker process entry point that sets up environment and runs async event loop.""" + # Set up environment variables for distributed training + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank) + + # Create a new event loop for this process + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # Run the async worker function + loop.run_until_complete( + async_worker(rank, world_size, model_dir, tensor_parallel_size, use_fsdp) + ) + finally: + # Clean up the event loop + loop.close() + +def main(): + parser = argparse.ArgumentParser( + description="LLM models with the PyTorch workflow.") + + parser.add_argument('--model_dir', + type=str, + required=True, + default='/model/Qwen2.5-0.5B-Instruct', + help="Model checkpoint directory.") + + parser.add_argument('--tensor_parallel_size', + type=int, + default=2, + help="Tensor parallel size (number of GPUs to use)") + + parser.add_argument('--use_fsdp', + action='store_true', + help="Use FSDP model loading instead of direct TensorRT-LLM loading") + + args = parser.parse_args() + + world_size = args.tensor_parallel_size + master_port = get_free_port() + mp.spawn(worker, args=(world_size, master_port, args.model_dir, args.tensor_parallel_size, args.use_fsdp), nprocs=world_size, join=True) + +if __name__ == '__main__': + main() + +#python3 examples/llm-api/rl_integration_test_async.py --model_dir /model/Qwen2.5-0.5B-Instruct --tensor_parallel_size 2 \ No newline at end of file diff --git a/examples/llm-api/rl_integration_test_async_pg.py b/examples/llm-api/rl_integration_test_async_pg.py new file mode 100644 index 00000000000..4a8f1306316 --- /dev/null +++ b/examples/llm-api/rl_integration_test_async_pg.py @@ -0,0 +1,722 @@ +""" +RL Integration Test with Placement Group Support for AsyncLLM + +This script demonstrates how to use TensorRT-LLM AsyncLLM with Ray Placement Groups +for resource management in RLHF training scenarios. 
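+
+High-level flow: the driver initializes Ray and creates a STRICT_PACK placement group,
+torch.multiprocessing spawns one trainer process per rank, rank 0 builds AsyncLLM on the
+requested bundles and initializes its Ray workers via async_init_phase(), and the FSDP
+weights are then streamed to the inference workers through CUDA IPC handles.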
+ +Usage: + python rl_integration_test_async_pg.py --model_dir /path/to/model --tensor_parallel_size 2 + +Features: + - Automatic Ray Placement Group creation and management + - Direct placement_where configuration + - Compatible with FSDP weight updates + - Multi-process distributed execution +""" + +import argparse +import torch +import pynvml +import contextlib +import torch.distributed as dist +import atexit +import os +import asyncio +from typing import Any, Optional, Generator + +from tensorrt_llm import SamplingParams +from tensorrt_llm import AsyncLLM +from tensorrt_llm.llmapi.llm_args import KvCacheConfig +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + StateDictType, + MixedPrecision, + ShardedStateDictConfig, + FullStateDictConfig +) +#from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType +from transformers import AutoModelForCausalLM, AutoTokenizer +from torch.distributed.tensor import DTensor +import torch.multiprocessing as mp +from tensorrt_llm._utils import get_free_port + +def init_distributed(): + """Initialize distributed training""" + if "LOCAL_RANK" not in os.environ: + return 1, 0, torch.device("cuda:0") + + # Set default environment variables if not already set + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "29500" + + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + world_size = dist.get_world_size() + rank = dist.get_rank() + torch.cuda.set_device(rank) + + return world_size, rank, torch.device(f"cuda:{rank}") + +def exit_distributed(): + """Exit distributed training""" + if dist.is_initialized(): + dist.destroy_process_group() + +def report_device_id() -> str: + """Report the UUID of the current CUDA device using NVML. + Returns: + str: UUID of the device in the format "GPU-xxxxx" + """ + from tensorrt_llm._torch.utils import get_device_uuid + # Get current device index from torch + device_idx = torch.cuda.current_device() + # Get device UUID using NVML + uuid = get_device_uuid(device_idx) + print(f"fsdp: id: {device_idx}, uuid: {uuid}") + return uuid + +@contextlib.contextmanager +def nvml_context() -> Generator[None, None, None]: + """Context manager for NVML initialization and shutdown. + + Raises: + RuntimeError: If NVML initialization fails + """ + try: + pynvml.nvmlInit() + yield + except pynvml.NVMLError as e: + raise RuntimeError(f"Failed to initialize NVML: {e}") + finally: + try: + pynvml.nvmlShutdown() + except: + pass + +def device_id_to_physical_device_id(device_id: int) -> int: + """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES.""" + import os + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + try: + physical_device_id = int(device_ids[device_id]) + return physical_device_id + except ValueError: + raise RuntimeError( + f"Failed to convert logical device ID {device_id} to physical device ID. Available devices are: {device_ids}." 
+ ) + else: + return device_id + +def get_free_memory_bytes(device_idx: int) -> float: + """Get the free memory of a CUDA device in bytes using NVML.""" + global_device_idx = device_id_to_physical_device_id(device_idx) + with nvml_context(): + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx) + return pynvml.nvmlDeviceGetMemoryInfo(handle).free + except pynvml.NVMLError as e: + raise RuntimeError( + f"Failed to get free memory for device {device_idx} (global index: {global_device_idx}): {e}" + ) + +class fsdp_interface: + def __init__(self, model_dir): + self.model_dir = model_dir + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + self.device = torch.device(f"cuda:{self.rank}") + self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + self.model = self.load_fsdp_model(model_dir) + + def load_fsdp_model(self, model_dir): + """Load and initialize FSDP model""" + # Initialize distributed setup + print(f"World size: {self.world_size}, Rank: {self.rank}, Device: {self.device}") + + # Setup mixed precision policy for FSDP + mixed_precision_policy = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32 + ) + + if self.rank == 0: + print(f"Loading FSDP model from {model_dir}") + + # Initialize FSDP model + fsdp_model = AutoModelForCausalLM.from_pretrained( + model_dir, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + device_map=self.device + ) + + # Print model info + if self.rank == 0: + total_params = sum(p.numel() for p in fsdp_model.parameters()) + trainable_params = sum(p.numel() for p in fsdp_model.parameters() if p.requires_grad) + print(f"Total parameters: {total_params:,}") + print(f"Trainable parameters: {trainable_params:,}") + print(f"Model device: {next(fsdp_model.parameters()).device}") + + # Wrap model with FSDP + fsdp_model = FSDP( + fsdp_model, + mixed_precision=mixed_precision_policy, + device_id=torch.cuda.current_device(), + use_orig_params=True + ) + + if self.rank == 0: + print("FSDP model initialized successfully") + + self._held_streamed_param_reference = None + self._held_sharded_state_dict_reference = None + + return fsdp_model + + + + def per_tensor_generator(self): + # If the model is not FSDP, then we need to manually move it to the GPU + # For an FSDP model, model.state_dict() will move the params to the GPU + if not isinstance(self.model, FSDP): + self.model = self.manual_load_to_gpu(self.model) + self._held_sharded_state_dict_reference = self.model.state_dict() + else: + # Get sharded state dict instead of full state dict for FSDP1 + with FSDP.state_dict_type( + self.model, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig() + ): + self._held_sharded_state_dict_reference = self.model.state_dict() + for name, param in self._held_sharded_state_dict_reference.items(): + yield name, param + + @torch.no_grad() + def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: + # If the model is not FSDP, then we need to manually move it to the GPU + # For an FSDP model, model.state_dict() will move the params to the GPU + if not isinstance(self.model, FSDP): + self.model = self.manual_load_to_gpu(self.model) + self._held_sharded_state_dict_reference = self.model.state_dict() + else: + # Get sharded state dict instead of full state dict for FSDP1 + with FSDP.state_dict_type( + self.model, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig() + ): + 
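+                # FULL_STATE_DICT materializes each parameter as a full (unsharded) tensor on
+                # every rank, so downstream consumers receive plain tensors rather than DTensors.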
self._held_sharded_state_dict_reference = self.model.state_dict() + + # Collect info for streaming multiple tensors + ### state_dict_info = [] + ### for name, tensor in self._held_sharded_state_dict_reference.items(): + ### # dtensor's numel will return complete tensor instead of only local tensor + ### size_in_bytes = tensor.element_size() * tensor.numel() + ### state_dict_info.append((name, size_in_bytes)) + self.refit_param_info = [] + for name, tensor in self._held_sharded_state_dict_reference.items(): + # dtensor's numel will return complete tensor instead of only local tensor + size_in_bytes = tensor.element_size() * tensor.numel() + self.refit_param_info.append((name, size_in_bytes)) + + #print(f"State dict info: {state_dict_info}") + # Collect current available memory for refit + ## Get current device index from torch + device_idx = torch.cuda.current_device() + ## Get device free memory using NVML + total_available_bytes = get_free_memory_bytes(device_idx) + ## Use 80% of the free memory for safety + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8") + total_available_bytes *= float(memory_ratio) + + return self.refit_param_info, total_available_bytes + + @torch.no_grad() + def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: + from torch.distributed.tensor import DTensor + from torch.multiprocessing.reductions import reduce_tensor + + assert self._held_sharded_state_dict_reference is not None, ( + "prepare_weights_for_ipc must be called before get_weights_ipc_handles" + ) + + # Clean up the held tensors to reduce peak memory + if self._held_streamed_param_reference is not None: + del self._held_streamed_param_reference + self._held_streamed_param_reference = None + + converted_params = {} + for key in keys: + # Get full_tensor for dtensor (GPU > 1) + if not key.startswith("model."): + continue + tensor = self._held_sharded_state_dict_reference[key] + if isinstance(tensor, DTensor): + full_tensor = tensor.full_tensor() + else: + full_tensor = tensor + # Convert parameters to the configured dtype + #print(f"FSDP rank {self.rank} name: {key}, shape: {full_tensor.shape}, {full_tensor[0]}") + converted_params[key] = full_tensor + + # Temporary record the full tensor for cleanup + # It is needed for cleanup the last full_tensor in the refit process + self._held_streamed_param_reference = converted_params + + # Get device UUID for IPC + device_uuid = report_device_id() + # Create handles for the tensors + all_handles = [] + for key, p in converted_params.items(): + handle = reduce_tensor(p.detach()) + all_handles.append((key, handle)) + + #print(f"device_uuid: {device_uuid}, All handles keys: {[key for key, _ in all_handles]}") + print(f"device_uuid: {device_uuid}") + return {device_uuid: all_handles} + + @torch.no_grad() + def prepare_weights_for_ipc_refit( + self, _refit_buffer_size_gb: Optional[int] = None + ) -> list[list[str]]: + """Prepare the weights for IPC. + + Returns: + list: A list containing the keys of the parameters, which is grouped by size. 
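+
+        Example (illustrative sizes, assuming _refit_buffer_size_gb=1): for
+        refit_param_info [("a", 0.6 GiB), ("b", 0.6 GiB), ("c", 0.2 GiB)] the
+        grouping is [["a"], ["b", "c"]], since "a" alone exhausts the first
+        1 GiB chunk while "b" and "c" fit together in the next one.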
+ """ + # Get the state_dict_info and available memory from all workers + state_dict_info = self.refit_param_info + + if _refit_buffer_size_gb is not None: + total_available_bytes = _refit_buffer_size_gb * (1024**3) + else: + # Get the minimum available memory from all workers + total_available_bytes = min(result[1] for result in state_dict_info) + + # Group tensors by size + cur_available_bytes = total_available_bytes + grouped_param_keys: list[list[str]] = [] + keys: list[str] = [] + + for key, size_in_bytes in state_dict_info: + if size_in_bytes > cur_available_bytes: + if keys: + grouped_param_keys.append(keys) + keys = [] + cur_available_bytes = total_available_bytes + + keys.append(key) + cur_available_bytes -= size_in_bytes + + if keys: + grouped_param_keys.append(keys) + + return grouped_param_keys + +class NamedParam: + def __init__(self, name, size, param): + self.name = name + self.size = size + self.param = param + +class GateAndUp: + def __init__(self): + self.gate = None + self.up = None + def set_gate(self, gate): + self.gate = gate + def set_up(self, up): + self.up = up + def get_size(self): + return self.gate.size + self.up.size + def is_complete(self): + return self.gate is not None and self.up is not None + +class trtllm_interface: + def __init__(self, model_dir, tensor_parallel_size, placement_group=None, bundle_indices=None): + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + self.device = torch.device(f"cuda:{self.rank}") + self.model_dir = model_dir + self.tensor_parallel_size = tensor_parallel_size + self.placement_group = placement_group + self.bundle_indices = bundle_indices + + async def init_trtllm(self): + self.llm = await self.load_trtllm_model( + self.model_dir, + self.tensor_parallel_size, + self.placement_group, + self.bundle_indices + ) + + async def load_trtllm_model(self, model_dir, tensor_parallel_size, placement_group=None, bundle_indices=None): + if self.rank == 0: + print(f"Loading TensorRT-LLM model: {model_dir}, tensor_parallel_size: {tensor_parallel_size}") + print(f"placement_group: {placement_group}, bundle_indices: {bundle_indices}") + + # Save and clear distributed environment variables to avoid conflicts + # Ray orchestrator will set up its own process group in separate actors + saved_env = {} + env_vars_to_clear = ['LOCAL_RANK', 'RANK', 'WORLD_SIZE', 'LOCAL_WORLD_SIZE'] + for var in env_vars_to_clear: + if var in os.environ: + saved_env[var] = os.environ[var] + del os.environ[var] + + try: + # Build placement_where if placement_group and bundle_indices are provided + placement_where = None + + if placement_group is not None and bundle_indices is not None: + # Simple case: one placement group with specified bundle indices + placement_where = [(placement_group, bundle_indices)] + print(f"placement_where: {placement_where}") + + llm = AsyncLLM( + model=model_dir, + tensor_parallel_size=tensor_parallel_size, + orchestrator_type='ray', + ray_worker_extension_cls='tensorrt_llm.llmapi.rlhf_utils.WorkerExtension', + load_format='dummy', + #enable_sleep=True, # crash + kv_cache_config=KvCacheConfig( + free_gpu_memory_fraction=0.85, + enable_block_reuse=False + ), + placement_where=placement_where, + ) + await llm.async_init_phase() + finally: + # Restore environment variables + for var, value in saved_env.items(): + os.environ[var] = value + + return llm + else: + return None + + def update_weights_from_ipc_handles(self, rank, device_handles): + if rank == 0: + gathered_handles = [None for _ in range(dist.get_world_size())] + 
else: + gathered_handles = None + dist.gather_object( + obj=device_handles, + object_gather_list=gathered_handles, + dst=0 + ) + if rank == 0: + all_handles = {k: v for d in gathered_handles for k, v in d.items()} + result = self.llm._collective_rpc('update_weights', (all_handles, )) + return result + else: + return None + + def update_weights_from_tensor_generator(self, tensor_generator): + device_uuid = report_device_id() + rank = dist.get_rank() + from torch.multiprocessing.reductions import reduce_tensor + total_available_bytes = 0.7 * (1024**3) + cur_available_bytes = total_available_bytes + converted_params = {} + cur_handles = [] + gate_up = {} + stream_step = 0 + for name, param in tensor_generator: + size_in_bytes = param.element_size() * param.numel() + if isinstance(param, DTensor): + param = param.full_tensor() + gate_up_name = None + gate_up_pair = None + if "gate_proj" in name: + gate_up_name = name.replace("gate_proj", "") + if (gate_up_name not in gate_up): + gate_up[gate_up_name] = GateAndUp() + assert gate_up[gate_up_name].gate is None + gate_up[gate_up_name].set_gate(NamedParam(name, size_in_bytes, param)) + elif "up_proj" in name: + gate_up_name = name.replace("up_proj", "") + if (gate_up_name not in gate_up): + gate_up[gate_up_name] = GateAndUp() + assert gate_up[gate_up_name].up is None + gate_up[gate_up_name].set_up(NamedParam(name, size_in_bytes, param)) + if (gate_up_name is not None): + if gate_up[gate_up_name].is_complete(): + gate_up_pair = gate_up.pop(gate_up_name) + size_in_bytes = gate_up_pair.get_size() + else: + continue + + if size_in_bytes > cur_available_bytes: + stream_step += 1 + device_handles = {device_uuid: cur_handles} + print(f"stream_step: {stream_step}") + result = self.update_weights_from_ipc_handles(rank, device_handles) + print(f"update_weights_from_ipc_handles result: {result}") + cur_available_bytes = total_available_bytes + del converted_params + converted_params = {} + cur_handles = [] + + assert cur_available_bytes >= size_in_bytes + cur_available_bytes -= size_in_bytes + if (gate_up_pair is not None): + converted_params[gate_up_pair.gate.name] = gate_up_pair.gate.param + converted_params[gate_up_pair.up.name] = gate_up_pair.up.param + handle = reduce_tensor(gate_up_pair.gate.param.detach()) + cur_handles.append((gate_up_pair.gate.name, handle)) + handle = reduce_tensor(gate_up_pair.up.param.detach()) + cur_handles.append((gate_up_pair.up.name, handle)) + gate_up_pair = None + else: + converted_params[name] = param + handle = reduce_tensor(param.detach()) + cur_handles.append((name, handle)) + + assert len(gate_up) == 0 + + if cur_handles: + device_handles = {device_uuid: cur_handles} + stream_step += 1 + print(f"stream_step: {stream_step}") + result = self.update_weights_from_ipc_handles(rank, device_handles) + print(f"update_weights_from_ipc_handles result: {result}") + cur_available_bytes = total_available_bytes + del converted_params + converted_params = {} + cur_handles = [] + +def get_current_process_memory_info() -> int: + """ + Returns GPU memory usage for current process in bytes. 
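+
+    Assumes the current process has a CUDA context on GPU 0; returns 0 if the
+    process is not listed among that device's compute processes.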
+ """ + # Get current process ID + current_pid = os.getpid() + # Get device handle for GPU 0 + device_handle = pynvml.nvmlDeviceGetHandleByIndex(0) + + # Get running processes + processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle) + + # Find current process + for process in processes: + if process.pid == current_pid: + return process.usedGpuMemory + + return 0 + +def get_current_mem_info(message: str = ""): + import nvsmi + mem_allocated = torch.cuda.memory_allocated() + mem_reserved = torch.cuda.memory_reserved() + mem_free, mem_total = torch.cuda.mem_get_info() + process_mem_info = get_current_process_memory_info() + print(f"{message} mem_free: {mem_free:,}, mem_total: {mem_total:,}, mem_allocated: {mem_allocated:,}, mem_reserved: {mem_reserved:,}, process_mem_info: {process_mem_info:,}") + for gpu in nvsmi.get_gpus(): + print(gpu) + return mem_free, mem_total, mem_allocated, mem_reserved, process_mem_info + +def get_total_available_bytes(pg: dist.ProcessGroup, message: str = "") -> int: + mem_allocated = torch.cuda.memory_allocated() + mem_reserved = torch.cuda.memory_reserved() + mem_free, mem_total = torch.cuda.mem_get_info() + print(f"{message} mem_free: {mem_free:,}, mem_total: {mem_total:,}, mem_allocated: {mem_allocated:,}, mem_reserved: {mem_reserved:,}") + mem_free = torch.tensor(mem_free) + dist.all_reduce(mem_free, op=dist.ReduceOp.MIN, group=pg) + mem_free = mem_free.item() + print(f"{message} gathered_mem_free: {mem_free:,}") + return mem_free * 0.2 + +def cleanup(): + """Cleanup function to destroy process group""" + if dist.is_initialized(): + print(f"Cleaning up process group on rank {dist.get_rank()}") + dist.destroy_process_group() + +async def async_worker(rank, world_size, model_dir, tensor_parallel_size, use_fsdp, placement_group=None, bundle_indices=None): + #os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" + #os.environ["TRTLLM_RAY_BUNDLE_INDICES"] = "1,2,3,4,5,6,7" + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" + os.environ["TRTLLM_RAY_BUNDLE_INDICES"] = "1,2" + #os.environ["TRTLLM_RAY_PER_WORKER_GPUS"] = "1" + + """Async worker function that runs the actual test logic within an event loop.""" + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + tags = ["sampler", + "drafter", + "guided_decoder", + "spec_resource_manager", + "_no_capture_model_extra", + "executor_extra", + "kv_cache", + "model", + "draft_model"] + + world_size, rank, device = init_distributed() + + sampling_params = SamplingParams(max_tokens=32) + + # Load FSDP model + fsdp = fsdp_interface(model_dir) + trtllm = trtllm_interface(model_dir, tensor_parallel_size, placement_group, bundle_indices) + await trtllm.init_trtllm() + + if rank == 0: + print(f"Collected handles from all {world_size} ranks:") + + # For FSDP mode, we would need additional logic to integrate withTensorRT-LLM + # This is a placeholder for now + if rank == 0: + for prompt in prompts: + outputs = await trtllm.llm.generate_async(prompt, sampling_params) + generated_text = outputs.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + ## load the model from fsdp + ## then generate the output again + ## get_current_mem_info("Before sleep") + ## result = trtllm.llm._collective_rpc('sleep', args=(tags,)) + ## print(f"sleep result: {result}") + ## get_current_mem_info("After sleep") +## + ## result = trtllm.llm._collective_rpc('wakeup', args=(tags,)) + ## print(f"wakeup result: {result}") + ## 
get_current_mem_info("After wakeup") + + trtllm.update_weights_from_tensor_generator(fsdp.per_tensor_generator()) + + # generate the output again + if rank == 0: + for prompt in prompts: + outputs = await trtllm.llm.generate_async(prompt, sampling_params) + generated_text = outputs.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + ## load the model from fsdp + ## then generate the output again + ## get_current_mem_info("Before sleep") + ## result = trtllm.llm._collective_rpc('sleep', args=(tags,)) + ## print(f"sleep result: {result}") + ## get_current_mem_info("After sleep") +## + ## result = trtllm.llm._collective_rpc('wakeup', args=(tags,)) + ## print(f"wakeup result: {result}") + ## get_current_mem_info("After wakeup") + + + ##trtllm.update_weights_from_tensor_generator(fsdp.per_tensor_generator()) +## + ### generate the output again + ##if rank == 0: + ## outputs = trtllm.llm.generate(prompts, sampling_params) + ## for i, output in enumerate(outputs): + ## prompt = output.prompt + ## generated_text = output.outputs[0].text + ## print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}") +## + exit_distributed() + +def worker(rank, world_size, master_port, model_dir, tensor_parallel_size, use_fsdp, placement_group=None, bundle_indices=None): + """Worker process entry point that sets up environment and runs async event loop.""" + # Set up environment variables for distributed training + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank) + + # Create a new event loop for this process + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # Run the async worker function + loop.run_until_complete( + async_worker(rank, world_size, model_dir, tensor_parallel_size, use_fsdp, placement_group, bundle_indices) + ) + finally: + # Clean up the event loop + loop.close() + +def main(): + parser = argparse.ArgumentParser( + description="LLM models with the PyTorch workflow.") + + parser.add_argument('--model_dir', + type=str, + required=True, + default='/model/Qwen2.5-0.5B-Instruct', + help="Model checkpoint directory.") + + parser.add_argument('--tensor_parallel_size', + type=int, + default=2, + help="Tensor parallel size (number of GPUs to use)") + + parser.add_argument('--use_fsdp', + action='store_true', + help="Use FSDP model loading instead of direct TensorRT-LLM loading") + + args = parser.parse_args() + + # Initialize Ray and create placement group + import ray + from ray.util.placement_group import placement_group as create_placement_group + + # Initialize Ray + ray.init(ignore_reinit_error=True) + print(f"Ray initialized. 
Available resources: {ray.cluster_resources()}") + + # Create placement group with bundles for all GPUs + # Each bundle gets 1.0 GPU (or adjust as needed for sharing) + n_gpus = 8 + bundles = [{"GPU": 1.0, "CPU": 1} for _ in range(n_gpus)] + + placement_group = create_placement_group( + bundles=bundles, + strategy="STRICT_PACK", # Keep all GPUs on same node + name=f"trtllm_pg_{n_gpus}gpus" + ) + + # Wait for placement group to be ready + ray.get(placement_group.ready()) + print(f"Placement group created with {n_gpus} bundles") + + # Use all bundles + bundle_indices = list(range(args.tensor_parallel_size)) + print(f"Using bundle indices: {bundle_indices}") + + world_size = args.tensor_parallel_size + master_port = get_free_port() + + # Spawn workers with placement group and bundle indices + mp.spawn( + worker, + args=(world_size, master_port, args.model_dir, args.tensor_parallel_size, + args.use_fsdp, placement_group, bundle_indices), + nprocs=world_size, + join=True + ) + + # Cleanup Ray + from ray.util.placement_group import remove_placement_group + if placement_group is not None: + remove_placement_group(placement_group) + ray.shutdown() + +if __name__ == '__main__': + main() + +#python3 examples/llm-api/rl_integration_test_async.py --model_dir /model/Qwen2.5-0.5B-Instruct --tensor_parallel_size 2 \ No newline at end of file diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index 978cf0796f1..cea56431b77 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -84,7 +84,7 @@ def _preload_python_lib(): from .builder import BuildConfig, Builder, BuilderConfig, build from .disaggregated_params import DisaggregatedParams from .functional import Tensor, constant -from .llmapi import LLM, MultimodalEncoder +from .llmapi import LLM, AsyncLLM, MultimodalEncoder from .llmapi.llm_args import LlmArgs, TorchLlmArgs, TrtLlmArgs from .logger import logger from .mapping import Mapping @@ -136,6 +136,7 @@ def _preload_python_lib(): 'quantization', 'tools', 'LLM', + 'AsyncLLM', 'MultimodalEncoder', 'LlmArgs', 'TorchLlmArgs', diff --git a/tensorrt_llm/executor/ray_executor.py b/tensorrt_llm/executor/ray_executor.py index ad8b838217e..9fd26fbb92f 100644 --- a/tensorrt_llm/executor/ray_executor.py +++ b/tensorrt_llm/executor/ray_executor.py @@ -1,3 +1,4 @@ +import asyncio import os from typing import Any, Dict, List, Optional, Tuple @@ -7,8 +8,7 @@ e.msg = """Cannot import Ray. 
Please install 'ray' package to use ray orchestrator""" raise -from ray.util.placement_group import (PlacementGroup, - PlacementGroupSchedulingStrategy, +from ray.util.placement_group import (PlacementGroupSchedulingStrategy, get_current_placement_group, placement_group) @@ -23,6 +23,7 @@ from .request import GenerationRequest from .result import GenerationResult, RayAsyncQueue, RaySyncQueue from .rpc_proxy import RpcExecutorMixin +from .utils import has_event_loop __all__ = [ "RayExecutor", @@ -78,14 +79,16 @@ def __init__(self, self.master_port = get_free_port() self.use_rpc = ray_use_rpc() - worker_kwargs = dict(**worker_kwargs, - postproc_worker_config=postproc_worker_config, - is_llm_executor=is_llm_executor) + self.worker_kwargs = dict( + **worker_kwargs, + postproc_worker_config=postproc_worker_config, + is_llm_executor=is_llm_executor) if self.use_rpc: self.init_rpc_executor() - worker_kwargs['rpc_addr'] = self.rpc_addr - self.create_workers(RayGPUWorker, worker_kwargs) + self.worker_kwargs['rpc_addr'] = self.rpc_addr + if not has_event_loop(): + self.init_workers_sync() self.setup_engine_remote() self.setup_mainloop(tasks=[self._fetch_responses_loop_async], thread_name="ray_executor_main_loop") @@ -107,7 +110,8 @@ def __init__(self, self.response_sync_queue) self.response_queue.warmup.remote() self.response_sync_queue.warmup.remote() - self.create_workers(RayGPUWorker, worker_kwargs) + if not has_event_loop(): + self.init_workers_sync() except Exception as e: self.shutdown() @@ -115,9 +119,13 @@ def __init__(self, raise e def create_workers(self, worker_cls, worker_kwargs): + llm_args = worker_kwargs.get("llm_args") + # When set to be a fraction, it allows Ray to schedule # multiple actors on a single GPU for colocate use cases. - num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0")) + num_gpus = (llm_args.per_worker_gpu_share if llm_args + and llm_args.per_worker_gpu_share is not None else float( + os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))) logger.debug(f"{num_gpus=} for each worker.") runtime_env = ray.runtime_env.RuntimeEnv() @@ -128,28 +136,40 @@ def create_workers(self, worker_cls, worker_kwargs): "MASTER_PORT": str(self.master_port) }) - self.placement_group, self.bundle_indices = self._get_placement_group( - tp_size=self.tp_size) + placement_groups, self.bundle_indices = self._get_placement_group( + tp_size=self.tp_size, worker_kwargs=worker_kwargs) + + if isinstance(placement_groups, list): + self.placement_group = None + else: + self.placement_group = placement_groups - self.workers = [ - RayWorkerWrapper.options( + self.workers = [] + for rank in range(self.world_size): + pg = placement_groups[rank] if isinstance( + placement_groups, list) else placement_groups + worker = RayWorkerWrapper.options( num_gpus=num_gpus, - runtime_env=runtime_env, # per-actor env + runtime_env=runtime_env, scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=self.placement_group, + placement_group=pg, placement_group_bundle_index=self.bundle_indices[rank], )).remote(worker_cls, worker_kwargs, self.world_size, rank) - for rank in range(self.world_size) - ] + self.workers.append(worker) + + def init_workers_sync(self): + self.create_workers(RayGPUWorker, self.worker_kwargs) + try: + ray.get(self._get_worker_ready_futures()) + except ray.exceptions.ActorDiedError as e: + raise RuntimeError("RayGPUWorker died during initialization") from e + async def init_workers_async(self): + self.create_workers(RayGPUWorker, self.worker_kwargs) try: - 
             ray.get([worker.__ray_ready__.remote() for worker in self.workers])
+            await asyncio.gather(*self._get_worker_ready_futures())
         except ray.exceptions.ActorDiedError as e:
-            if "The actor died because of an error raised in its creation task" in str(
-                    e):
-                raise RuntimeError(
-                    "RayGPUWorker died during initialization") from e
-            raise
+            raise RuntimeError("RayGPUWorker died during initialization") from e
 
     @unwrap_ray_errors()
     def call_all_ray_workers(self, func: str, leader_only: bool,
@@ -316,15 +336,51 @@ def shutdown(self):
         logger.debug("Shutting down Ray cluster")
         ray.shutdown()
 
-    def _get_placement_group(self,
-                             tp_size: int) -> Tuple[PlacementGroup, List[int]]:
+    def _get_worker_ready_futures(self):
+        return [worker.__ray_ready__.remote() for worker in self.workers]
+
+    def _get_placement_group(
+            self,
+            tp_size: int,
+            worker_kwargs: Dict = None) -> Tuple[Any, List[int]]:
         """
         Either use the existing placement group from driver script (e.g., in the case of RL FW integration),
         or create a default PACK placement group where each bundle has tp_size GPUs.
         - When tp_size ≤ GPUs per node, keep one TP group per node.
         - When tp_size > GPUs per node, allow a TP group span nodes.
         - rank 0 must be put on the driver node
+
+        Returns:
+            Tuple of (placement_group(s), bundle_indices)
+            - placement_group(s) can be a single PlacementGroup or a List[PlacementGroup]
+            - bundle_indices is always a List[int]
         """
+        llm_args = worker_kwargs.get("llm_args") if worker_kwargs else None
+
+        if llm_args and hasattr(
+                llm_args,
+                'placement_groups') and llm_args.placement_groups is not None:
+            total_workers = sum(
+                len(indices) for indices in llm_args.placement_bundle_indices)
+            if total_workers != self.world_size:
+                raise ValueError(
+                    f"Total bundle indices ({total_workers}) must equal world_size ({self.world_size})"
+                )
+
+            logger.info(
+                f"Creating {self.world_size} workers with external placement groups"
+            )
+
+            flat_pgs = []
+            flat_indices = []
+            for pg, indices in zip(llm_args.placement_groups,
+                                   llm_args.placement_bundle_indices):
+                for idx in indices:
+                    flat_pgs.append(pg)
+                    flat_indices.append(idx)
+
+            return flat_pgs, flat_indices
+
         bundle_indices = os.getenv("TRTLLM_RAY_BUNDLE_INDICES", None)
         if bundle_indices:
diff --git a/tensorrt_llm/executor/ray_gpu_worker.py b/tensorrt_llm/executor/ray_gpu_worker.py
index 00dc1025f4d..7e87ded6345 100644
--- a/tensorrt_llm/executor/ray_gpu_worker.py
+++ b/tensorrt_llm/executor/ray_gpu_worker.py
@@ -44,7 +44,6 @@ class RayWorkerWrapper:
     def __init__(self, worker_cls, worker_kwargs, world_size, rank):
         self.master_address = os.environ["MASTER_ADDR"]
         self.master_port = os.environ["MASTER_PORT"]
-        # Ray can't pickle TensorRT logger
         global logger
         from tensorrt_llm.logger import logger
diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py
index cb868d8d068..1023a152a57 100644
--- a/tensorrt_llm/llmapi/__init__.py
+++ b/tensorrt_llm/llmapi/__init__.py
@@ -2,7 +2,7 @@
 from ..executor import CompletionOutput, LoRARequest, RequestError
 from ..sampling_params import GuidedDecodingParams, SamplingParams
 from .build_cache import BuildCacheConfig
-from .llm import LLM, RequestOutput
+from .llm import LLM, AsyncLLM, RequestOutput
 # yapf: disable
 from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
                        CacheTransceiverConfig, CalibConfig,
@@ -23,6 +23,7 @@
 __all__ = [
     'LLM',
+    'AsyncLLM',
     'MultimodalEncoder',
     'CompletionOutput',
     'RequestOutput',
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 32c6a90e327..44aa8ece8b0 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -48,6 +48,7 @@
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
 from .utils import (append_docstring, exception_handler, get_device_count,
                     logger_debug, set_api_status)
+from ray.util.placement_group import PlacementGroup, placement_group
 
 
 class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
@@ -189,7 +190,7 @@ def __init__(self,
         self.mpi_session = self.args.mpi_session
 
         if self.args.parallel_config.is_multi_gpu:
-            if get_device_count(
+            if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count(
             ) < self.args.parallel_config.world_size_per_node:
                 raise RuntimeError(
                     f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required."
@@ -225,7 +226,6 @@ def __init__(self,
             self.runtime_context: Optional[_ModelRuntimeContext] = None
             self.llm_build_stats = LlmBuildStats()
-
             self._build_model()
 
         except Exception:
@@ -1125,3 +1125,10 @@ def __init__(self,
     Parameters:
 """ + TORCH_LLM_DOCSTRING
+
+class AsyncLLM(LLM):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    async def async_init_phase(self):
+        await self._executor.init_workers_async()
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 40d967a2cb8..f547f482fc4 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -8,8 +8,9 @@
 from dataclasses import dataclass
 from enum import Enum, EnumMeta
 from pathlib import Path
-from typing import (Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple,
-                    Type, TypeAlias, TypeVar, Union, get_args, get_origin)
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
+                    Set, Tuple, Type, TypeAlias, TypeVar, Union, get_args,
+                    get_origin)
 
 import torch
 import yaml
@@ -19,6 +20,11 @@
 from strenum import StrEnum
 from transformers import PreTrainedTokenizerBase
 
+try:
+    from ray.util.placement_group import PlacementGroup
+except ImportError:
+    PlacementGroup = None
+
 from tensorrt_llm.lora_helper import (LoraConfig,
                                       get_default_trtllm_modules_to_hf_modules)
 
@@ -1926,6 +1932,8 @@ def validate_dtype(cls, v, info):
     @field_validator("gpus_per_node", mode='before')
     @classmethod
     def validate_gpus_per_node(cls, v, info):
+        if os.getenv("RAY_LOCAL_WORLD_SIZE") is not None:
+            return info.data.get("tensor_parallel_size")
         if v is None:
             logger.warning(
                 f"Using default gpus_per_node: {torch.cuda.device_count()}")
@@ -2701,6 +2709,26 @@ class TorchLlmArgs(BaseLlmArgs):
         "Allows users to extend the functions of the RayGPUWorker class.",
         status="prototype")
 
+    # Ray placement group config. Naming TBD.
+    placement_groups: Optional[List[Any]] = Field(
+        default=None,
+        description="List of Ray placement groups, one per node. "
+        "Each element must be a ray.util.placement_group.PlacementGroup instance.",
+        exclude_from_json=True,
+        status="prototype")
+
+    placement_bundle_indices: Optional[List[List[int]]] = Field(
+        default=None,
+        description="List of bundle indices for each placement group. "
+        "Outer list corresponds to placement_groups, inner list contains bundle indices for that group.",
+        status="prototype")
+
+    per_worker_gpu_share: Optional[float] = Field(
+        default=None,
+        description="GPU fraction per worker for colocation scenarios. "
+        "Example: 0.1 means 10 actors can share one GPU. Defaults to 1.0 (one actor per GPU).",
+        status="prototype")
+
     enable_sleep: bool = Field(
         default=False,
         description=
@@ -2945,6 +2973,44 @@ def validate_ray_worker_extension_cls(self) -> 'TorchLlmArgs':
             )
         return self
 
+    @model_validator(mode='after')
+    def validate_ray_placement_config(self) -> 'TorchLlmArgs':
+        has_pgs = self.placement_groups is not None
+        has_indices = self.placement_bundle_indices is not None
+
+        if (has_pgs or has_indices) and self.orchestrator_type != "ray":
+            raise ValueError(
+                "placement_groups is only supported with orchestrator_type='ray'"
+            )
+
+        if has_pgs != has_indices:
+            raise ValueError(
+                "placement_groups and placement_bundle_indices must be provided together"
+            )
+
+        if has_pgs:
+            if len(self.placement_groups) != len(self.placement_bundle_indices):
+                raise ValueError(
+                    f"placement_groups length ({len(self.placement_groups)}) must equal "
+                    f"placement_bundle_indices length ({len(self.placement_bundle_indices)})"
+                )
+
+        if self.per_worker_gpu_share is not None:
+            if not (0 < self.per_worker_gpu_share <= 1.0):
+                raise ValueError(
+                    f"per_worker_gpu_share must be between 0 and 1.0, "
+                    f"got {self.per_worker_gpu_share}")
+
+        if has_pgs:
+            if PlacementGroup is not None:
+                for i, pg in enumerate(self.placement_groups):
+                    if not isinstance(pg, PlacementGroup):
+                        raise TypeError(
+                            f"placement_groups[{i}] must be a Ray PlacementGroup, "
+                            f"got {type(pg).__name__}")
+
+        return self
+
     def get_executor_config(
             self,
             _hf_model_dir: Optional[Path] = None,
diff --git a/tensorrt_llm/llmapi/rlhf_utils.py b/tensorrt_llm/llmapi/rlhf_utils.py
index b3d63ec236e..ad87cce1c07 100644
--- a/tensorrt_llm/llmapi/rlhf_utils.py
+++ b/tensorrt_llm/llmapi/rlhf_utils.py
@@ -3,6 +3,8 @@
 from tensorrt_llm._ray_utils import control_action_decorator
 from tensorrt_llm._torch.utils import get_device_uuid
 from tensorrt_llm.logger import logger
+import pickle
+import base64
 
 
 class WorkerExtension:
@@ -52,7 +54,7 @@ def update_weights(self, ipc_handles: dict):
             raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")
 
         weights = {}
-        all_handles = ipc_handles[device_uuid]
+        all_handles = pickle.loads(base64.b64decode(ipc_handles[device_uuid]))
 
         for param_name, tensor_handle in all_handles:
             func, args = tensor_handle
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index bab8025fda1..15672b52bc1 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -3,7 +3,7 @@
 import base64
 import time
 import uuid
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union, Tuple
 
 import torch
 import xgrammar
@@ -931,6 +931,14 @@ class ResponsesStreamResponse(OpenAIBaseModel):
                         "response.incomplete"]
 
 
+class MemoryUpdateRequest(OpenAIBaseModel):
+    tags: List[str] = Field(default=["model", "kv_cache"])
+
+
+class UpdateWeightsRequest(OpenAIBaseModel):
+    weights: Dict[str, str]
+
+
 def encode_opaque_state(opaque_state: Optional[bytes]) -> Optional[str]:
     if opaque_state is None:
         return None
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index 05facda203a..a76adba321b 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -47,7 +47,9 @@
                                                ErrorResponse, ModelCard,
                                                ModelList, PromptTokensDetails,
                                                ResponsesRequest, UsageInfo,
-                                               to_llm_disaggregated_params)
+                                               to_llm_disaggregated_params,
+                                               MemoryUpdateRequest,
+                                               UpdateWeightsRequest)
 from tensorrt_llm.serve.postprocess_handlers import (
     ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
     chat_harmony_post_processor, chat_harmony_streaming_post_processor,
@@ -255,6 +257,15 @@ def register_routes(self):
         self.app.add_api_route("/v1/responses",
                                self.openai_responses,
                                methods=["POST"])
+        self.app.add_api_route("/release_memory",
+                               self.release_memory,
+                               methods=["POST"])
+        self.app.add_api_route("/resume_memory",
+                               self.resume_memory,
+                               methods=["POST"])
+        self.app.add_api_route("/update_weights",
+                               self.update_weights,
+                               methods=["POST"])
         if self.llm.args.return_perf_metrics:
             # register /prometheus/metrics
             self.mount_metrics()
@@ -291,6 +302,15 @@ def register_mm_encoder_routes(self):
         self.app.add_api_route("/v1/chat/completions",
                                self.openai_mm_encoder,
                                methods=["POST"])
+        self.app.add_api_route("/release_memory",
+                               self.release_memory,
+                               methods=["POST"])
+        self.app.add_api_route("/resume_memory",
+                               self.resume_memory,
+                               methods=["POST"])
+        self.app.add_api_route("/update_weights",
+                               self.update_weights,
+                               methods=["POST"])
 
     async def health(self) -> Response:
         return Response(status_code=200)
@@ -975,6 +995,42 @@ async def create_stream_response(generator, request: ResponsesRequest, sampling_
 
         return JSONResponse(content={"detail": "None"})
 
+    async def release_memory(self, request: MemoryUpdateRequest) -> JSONResponse:
+        #await self.llm.sleep_async(level=2)
+
+        tags = ["sampler",
+                "drafter",
+                "guided_decoder",
+                "spec_resource_manager",
+                "_no_capture_model_extra",
+                "executor_extra",
+                "kv_cache",
+                "model",
+                "draft_model"]
+        #self.llm._collective_rpc('sleep', args=(tags,))
+        print(f"HTTP received: release_memory {tags}")
+        return JSONResponse(content={"status": "success"})
+
+    async def resume_memory(self, request: MemoryUpdateRequest) -> JSONResponse:
+        #await self.llm.wakeup_async(level=2)
+        tags = ["sampler",
+                "drafter",
+                "guided_decoder",
+                "spec_resource_manager",
+                "_no_capture_model_extra",
+                "executor_extra",
+                "kv_cache",
+                "model",
+                "draft_model"]
+        #self.llm._collective_rpc('wakeup', args=(tags,))
+        print(f"HTTP received: resume_memory {tags}")
+        return JSONResponse(content={"status": "success"})
+
+    async def update_weights(self, request: UpdateWeightsRequest) -> JSONResponse:
+        #await self.llm.update_weights_from_ipc_handles_async(request.weights)
+        print(f"HTTP received: update_weights")
+        self.llm._collective_rpc('update_weights', args=(request.weights,))
+        return JSONResponse(content={"status": "success"})
 
     async def __call__(self, host, port):
         # Store the binding address for server registration
diff --git a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py
index 492d7b08182..232513d9a85 100644
--- a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py
+++ b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py
@@ -27,9 +27,7 @@ def test_worker_extension():
 
 
 @pytest.mark.gpu4
-def test_bundle_indices(monkeypatch):
-    """Placement via bundle indices"""
-
+def test_placement_env_vars(monkeypatch):
     monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")
     monkeypatch.setenv("TLLM_RAY_USE_RPC", "1")
 
@@ -78,6 +76,53 @@ def test_bundle_indices(monkeypatch):
     ray.shutdown()
 
 
+@pytest.mark.gpu2
+# @pytest.mark.gpu4
+@pytest.mark.parametrize(
+    "n_gpus,bundle_indices",
+    [
+        (2, [1]),
+        # (4, [2, 3]),
+    ],
+    ids=["gpu2_tp1"]  # , "gpu4_tp2"
+)
+def test_placement_api(monkeypatch, n_gpus, bundle_indices):
+    monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")
+    monkeypatch.setenv("TLLM_RAY_USE_RPC", "1")
+
+    tp_size = n_gpus // 2
+    pg = None
+    try:
+        ray.init()
+        pg = placement_group([{"GPU": 1, "CPU": 1}] * n_gpus)
+        ray.get(pg.ready())
+        print(f"Placement group ready with bundles {pg.bundle_specs}")
+
+        llm = LLM(
+            model=os.path.join(llm_models_root(), "llama-models-v2",
+                               "TinyLlama-1.1B-Chat-v1.0"),
+            kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
+            tensor_parallel_size=tp_size,
+            orchestrator_type="ray",
+            placement_groups=[pg],
+            placement_bundle_indices=[bundle_indices],
+            per_worker_gpu_share=0.8,
+        )
+
+        inference_actor_uuids = llm._collective_rpc("report_device_id")
+        expected_uuids = [get_device_uuid(idx) for idx in bundle_indices]
+
+        print(f"{inference_actor_uuids=}, all_uuids={[get_device_uuid(i) for i in range(n_gpus)]}")
+
+        assert sorted(inference_actor_uuids) == sorted(expected_uuids), \
+            f"Workers not placed on expected GPUs. Expected: {expected_uuids}, Got: {inference_actor_uuids}"
+
+    finally:
+        if pg is not None:
+            remove_placement_group(pg)
+        ray.shutdown()
+
+
 @pytest.mark.gpu2
 def test_cuda_visible_device(monkeypatch):
     """Placement via cuda_visible_device"""