diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index 33ef41af8a9..cbcc2926805 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -1506,7 +1506,7 @@ def _broadcast_cache_data( """Broadcast tactics from root rank to all other ranks.""" cache_data = self.profiling_cache.get_specific_custom_op(custom_op) root = 0 - cache_data = self._dist.tp_broadcast(obj=cache_data, root=root) + cache_data = self._dist.tp_cp_broadcast(obj=cache_data, root=root) self.profiling_cache.merge_cache_data(cache_data) diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py index 7b6af1188b0..09bbc234ee2 100644 --- a/tensorrt_llm/_torch/distributed/communicator.py +++ b/tensorrt_llm/_torch/distributed/communicator.py @@ -116,6 +116,26 @@ def broadcast(self, obj, root=0): def allgather(self, obj, root=0): pass + @abstractmethod + def tp_broadcast(self, obj, root=0, **kwargs): + pass + + @abstractmethod + def cp_broadcast(self, obj, root=0, **kwargs): + pass + + def tp_cp_broadcast(self, obj, root=0, **kwargs): + """Broadcast object across both TP and CP groups. + + This is used when both TP and CP parallelism are enabled (e.g., helix parallelism). + First broadcasts within the TP group, then within the CP group. + """ + if self.tp_size > 1: + obj = self.tp_broadcast(obj, root=root, **kwargs) + if self.cp_size > 1: + obj = self.cp_broadcast(obj, root=root, **kwargs) + return obj + def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024): """ @@ -407,6 +427,14 @@ def create_cp_comm(self): def cp_allgather(self, obj): return self.cp_comm.allgather(obj) + def cp_broadcast(self, + obj, + root=0, + chunk_size: int = 4 * 1024 * 1024, + **kwargs): + comm = self.cp_comm + return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size) + def tp_allgather(self, obj): return self.tp_comm.allgather(obj) @@ -414,7 +442,11 @@ def tp_gather(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024): comm = self.tp_comm return safe_gather(comm, obj, root=root, chunk_size=chunk_size) - def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024): + def tp_broadcast(self, + obj, + root=0, + chunk_size: int = 4 * 1024 * 1024, + **kwargs): comm = self.tp_comm return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size) @@ -699,7 +731,7 @@ def tp_gather(self, obj, dst=0): return output_list @log_op - def tp_broadcast(self, obj, root=0): + def tp_broadcast(self, obj, root=0, **kwargs): if isinstance(obj, torch.Tensor): dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg) return obj @@ -712,6 +744,20 @@ def tp_broadcast(self, obj, root=0): device=torch.device("cpu")) return ret[0] + @log_op + def cp_broadcast(self, obj, root=0, **kwargs): + if isinstance(obj, torch.Tensor): + dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg) + return obj + else: + ret = [obj] + torch.distributed.broadcast_object_list( + ret, + src=root, + group=self.mapping.cp_group_pg, + device=torch.device("cpu")) + return ret[0] + @log_op def pp_allgather(self, obj): if isinstance(obj, torch.Tensor): diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 120c42dbd2c..161282e4c4e 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -587,9 +587,10 @@ def _broadcast_new_requests( if not self.dist.has_pp: return self.dist.broadcast(payloads, 
root=0) - # Broadcast within first tp group before send/recv chain to other tp groups - if self.dist.tp_size > 1 and self.dist.is_first_pp_rank: - payloads = self.dist.tp_broadcast(payloads, root=0) + # Broadcast within first PP stage before send/recv chain to other PP stages. + # This needs to cover both TP and CP ranks within the first PP stage. + if self.dist.is_first_pp_rank: + payloads = self.dist.tp_cp_broadcast(payloads, root=0) # Tag for communication tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index c52911fe00b..8c7111f4490 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -871,7 +871,9 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @pytest.mark.skip_less_device(4) + @pytest.mark.skip_less_device(8) + @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)], + ids=["pp1tp2cp2", "pp2tp1cp2"]) @pytest.mark.parametrize("cuda_graph_config", [ None, { @@ -888,8 +890,10 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn): "cudagraph:with_padding" ]) @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"]) - def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config): + def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config, + gen_pp, gen_tp, gen_cp): use_nccl_for_alltoall = comms_medium == "nccl" + gen_ep = gen_tp * gen_cp kv_cache_config = { "free_gpu_memory_fraction": 0.5, "enable_block_reuse": False, @@ -898,20 +902,22 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config): } ctx_server_config = { "pipeline_parallel_size": 1, - "tensor_parallel_size": 2, + "tensor_parallel_size": 4, "context_parallel_size": 1, "disable_overlap_scheduler": True, "kv_cache_config": kv_cache_config, "enable_chunked_prefill": False, "cuda_graph_config": None, "cache_transceiver_config": { - "backend": "UCX" + "backend": "UCX", + "max_tokens_in_buffer": 8192, }, } gen_server_config = { - "tensor_parallel_size": 1, - "pipeline_parallel_size": 1, - "context_parallel_size": 2, + "tensor_parallel_size": gen_tp, + "pipeline_parallel_size": gen_pp, + "context_parallel_size": gen_cp, + "moe_expert_parallel_size": gen_ep, "cp_config": { "cp_type": "HELIX", "tokens_per_block": 32, @@ -922,7 +928,8 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config): "enable_chunked_prefill": False, "cuda_graph_config": cuda_graph_config, "cache_transceiver_config": { - "backend": "UCX" + "backend": "UCX", + "max_tokens_in_buffer": 8192, }, } disaggregated_server_config = { diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 9a7102142bf..e185f40fdd4 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -540,14 +540,14 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2] 
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp2tp1cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp2tp1cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp2tp1cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False] accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index c691acc1fef..04c2285858d 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -66,6 +66,7 @@ l0_dgx_b200: backend: pytorch orchestrator: mpi tests: + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60) @@ -92,6 +93,7 @@ l0_dgx_b200: backend: pytorch orchestrator: mpi tests: + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60) diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index bf0ca261fe0..2241aea415a 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -72,8 +72,6 @@ l0_gb200_multi_gpus: - 
accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass] - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass] - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm] - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none] - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none] - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90) - condition: ranges: @@ -89,10 +87,6 @@ l0_gb200_multi_gpus: stage: post_merge backend: pytorch tests: - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding] - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding] - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding] - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 1a85060db3e..c9643955df4 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -526,5 +526,17 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_sof accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] SKIP (https://nvbugs/5774869) triton_server/test_triton.py::test_llava_onevision[llava_onevision] SKIP (https://nvbugs/5775205) triton_server/test_triton.py::test_gpt_ib_lad[gpt-ib-lad] SKIP (https://nvbugs/5775223) +accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=False] SKIP (https://nvbugs/5777044) +accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True] SKIP (https://nvbugs/5777044) accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=True] SKIP (https://nvbugs/5740377) unittest/_torch/modules/test_fused_moe.py::test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu[MoEWeightLoadingMode.FUSED_GATE_UP_PROJ-DefaultMoeRoutingMethod-1] SKIP (https://nvbugs/5775256) +unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[3-2] SKIP (https://nvbugs/5777041) +unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[8-2] SKIP (https://nvbugs/5777041) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_2gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5777044) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-overlap_scheduler] SKIP (https://nvbugs/5777044) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-auto] SKIP (https://nvbugs/5777044) 
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True] SKIP (https://nvbugs/5777044) +accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=False] SKIP (https://nvbugs/5777044) +accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True] SKIP (https://nvbugs/5777044) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5777044) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto] SKIP (https://nvbugs/5777044) diff --git a/tests/unittest/_torch/distributed/test_cp_broadcast.py b/tests/unittest/_torch/distributed/test_cp_broadcast.py new file mode 100644 index 00000000000..d5ed278726e --- /dev/null +++ b/tests/unittest/_torch/distributed/test_cp_broadcast.py @@ -0,0 +1,265 @@ +""" +Tests for cp_broadcast functionality in both MPIDist and TorchDist. + +This module tests the context parallelism broadcast operation which is used +when CP (context parallelism) is enabled (e.g., in Helix parallelism). + +For MPIDist tests, run with mpirun: +mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py -v + +For TorchDist tests, see test_ops.py which uses Ray for distributed testing. +""" + +import numpy as np +import pytest + +from tensorrt_llm._torch.distributed import MPIDist +from tensorrt_llm.mapping import Mapping + + +def get_mpi_info(): + """Get MPI rank and world size, returns (0, 1) if MPI is not available.""" + try: + from mpi4py import MPI + + comm = MPI.COMM_WORLD + return comm.Get_rank(), comm.Get_size() + except ImportError: + return 0, 1 + + +def skip_if_not_mpi(): + """Skip test if not running under MPI with sufficient ranks.""" + rank, world_size = get_mpi_info() + if world_size < 2: + pytest.skip("Test requires at least 2 MPI ranks (run with mpirun -n 2)") + + +class TestMPIDistCpBroadcast: + """Tests for MPIDist.cp_broadcast functionality.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Set up MPI environment and mapping for each test.""" + skip_if_not_mpi() + self.rank, self.world_size = get_mpi_info() + + # Set up mapping with CP enabled (cp_size = world_size, tp_size = 1) + self.mapping = Mapping( + world_size=self.world_size, + rank=self.rank, + tp_size=1, + cp_size=self.world_size, + pp_size=1, + ) + self.dist = MPIDist(mapping=self.mapping) + + def test_broadcast_numpy_array(self): + """Test broadcasting a numpy array via cp_broadcast.""" + root = 0 + shape = (64, 128) + + if self.mapping.cp_rank == root: + # Root rank creates the data to broadcast + data = np.random.randn(*shape).astype(np.float32) + else: + # Non-root ranks have empty/zero data + data = np.zeros(shape, dtype=np.float32) + + # Store original data from root for verification + from mpi4py import MPI + + expected = np.zeros(shape, dtype=np.float32) + MPI.COMM_WORLD.Bcast(data if self.mapping.cp_rank == root else expected, root=root) + if self.mapping.cp_rank == root: + expected = data.copy() + + # Perform cp_broadcast + result = self.dist.cp_broadcast(data, root=root) + + # Verify all ranks have the same data + np.testing.assert_array_almost_equal(result, expected) + + def test_broadcast_python_dict(self): + """Test broadcasting a Python dictionary via cp_broadcast.""" + root = 0 + + if self.mapping.cp_rank == root: + obj = { + "model_name": "llama", + "batch_size": 32, + "tokens": [1, 2, 3, 4, 5], + "config": {"hidden_size": 4096, "num_layers": 32}, + } + else: + obj = 
None + + result = self.dist.cp_broadcast(obj, root=root) + + # Verify all ranks received the correct object + assert result["model_name"] == "llama" + assert result["batch_size"] == 32 + assert result["tokens"] == [1, 2, 3, 4, 5] + assert result["config"]["hidden_size"] == 4096 + assert result["config"]["num_layers"] == 32 + + def test_broadcast_python_list(self): + """Test broadcasting a Python list via cp_broadcast.""" + root = 0 + + if self.mapping.cp_rank == root: + obj = ["request1", "request2", {"id": 123, "data": [1, 2, 3]}] + else: + obj = None + + result = self.dist.cp_broadcast(obj, root=root) + + assert result == ["request1", "request2", {"id": 123, "data": [1, 2, 3]}] + + def test_broadcast_from_non_zero_root(self): + """Test broadcasting from a non-zero root rank.""" + if self.world_size < 2: + pytest.skip("Need at least 2 ranks to test non-zero root") + + root = 1 # Broadcast from rank 1 + + if self.mapping.cp_rank == root: + obj = {"source": "rank1", "value": 42} + else: + obj = None + + result = self.dist.cp_broadcast(obj, root=root) + + assert result["source"] == "rank1" + assert result["value"] == 42 + + def test_broadcast_large_object(self): + """Test broadcasting a large object that may require chunking.""" + root = 0 + # Create a large list to test chunking behavior + large_size = 100000 + + if self.mapping.cp_rank == root: + obj = list(range(large_size)) + else: + obj = None + + result = self.dist.cp_broadcast(obj, root=root) + + assert len(result) == large_size + assert result[0] == 0 + assert result[-1] == large_size - 1 + + def test_broadcast_string(self): + """Test broadcasting a simple string via cp_broadcast.""" + root = 0 + + if self.mapping.cp_rank == root: + obj = "Hello from root rank!" + else: + obj = None + + result = self.dist.cp_broadcast(obj, root=root) + + assert result == "Hello from root rank!" + + +# Additional integration-style test that can be run standalone +def test_mpi_cp_broadcast_integration(): + """ + Integration test for MPIDist cp_broadcast. 
+ """ + rank, world_size = get_mpi_info() + if world_size < 2: + pytest.skip("Test requires at least 2 MPI ranks") + + # Create mapping with CP enabled + mapping = Mapping( + world_size=world_size, + rank=rank, + tp_size=1, + cp_size=world_size, + pp_size=1, + ) + dist = MPIDist(mapping=mapping) + + # Test 1: Broadcast dict + if mapping.cp_rank == 0: + payload = {"requests": [{"id": i} for i in range(10)]} + else: + payload = None + + result = dist.cp_broadcast(payload, root=0) + assert len(result["requests"]) == 10 + assert result["requests"][0]["id"] == 0 + + # Test 2: Broadcast numpy array + shape = (32, 64) + if mapping.cp_rank == 0: + arr = np.ones(shape, dtype=np.float32) * (rank + 1) + else: + arr = np.zeros(shape, dtype=np.float32) + + result = dist.cp_broadcast(arr, root=0) + expected_val = 1.0 # From rank 0 + np.testing.assert_array_almost_equal(result, np.ones(shape) * expected_val) + + +if __name__ == "__main__": + # Allow running directly with mpirun + pytest.main([__file__, "-v"]) + + +class TestMPIDistTpCpBroadcast: + """Tests for MPIDist.tp_cp_broadcast functionality.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Set up MPI environment and mapping for each test.""" + skip_if_not_mpi() + self.rank, self.world_size = get_mpi_info() + + # Set up mapping with both TP and CP enabled + # For 2 ranks: tp_size=1, cp_size=2 (tp_cp_broadcast will only do cp_broadcast) + self.mapping = Mapping( + world_size=self.world_size, + rank=self.rank, + tp_size=1, + cp_size=self.world_size, + pp_size=1, + ) + self.dist = MPIDist(mapping=self.mapping) + + def test_tp_cp_broadcast_python_dict(self): + """Test broadcasting a Python dictionary via tp_cp_broadcast.""" + root = 0 + + # Only rank 0 in both TP and CP groups should have the object + if self.mapping.tp_rank == root and self.mapping.cp_rank == root: + obj = { + "model_name": "llama", + "batch_size": 32, + "tokens": [1, 2, 3, 4, 5], + } + else: + obj = None + + result = self.dist.tp_cp_broadcast(obj, root=root) + + # Verify all ranks received the correct object + assert result["model_name"] == "llama" + assert result["batch_size"] == 32 + assert result["tokens"] == [1, 2, 3, 4, 5] + + def test_tp_cp_broadcast_python_list(self): + """Test broadcasting a Python list via tp_cp_broadcast.""" + root = 0 + + if self.mapping.tp_rank == root and self.mapping.cp_rank == root: + obj = ["request1", "request2", {"id": 123, "data": [1, 2, 3]}] + else: + obj = None + + result = self.dist.tp_cp_broadcast(obj, root=root) + + assert result == ["request1", "request2", {"id": 123, "data": [1, 2, 3]}] diff --git a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py index 18ec8c81cfc..cbce606293f 100644 --- a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py +++ b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py @@ -258,3 +258,192 @@ def test_allreduce_pg_op(setup_ray_cluster, seq_len, hidden_size): ]) for r in results: assert r is True + + +@ray.remote(num_gpus=1) +class CpBroadcastTest: + """Test worker for cp_broadcast operations with context parallelism.""" + + def __init__(self, rank, world_size, tp_size, cp_size): + self.rank = rank + self.world_size = world_size + self.tp_size = tp_size + self.cp_size = cp_size + self.master_address = os.environ["MASTER_ADDR"] + + assert len(ray.get_gpu_ids()) == 1 + self.gpu = int(ray.get_gpu_ids()[0]) + from tensorrt_llm.executor.ray_gpu_worker import RayWorkerWrapper + local_gpu = 
RayWorkerWrapper.physical_to_local_id(self.gpu) + torch.cuda.set_device(local_gpu) + + def _create_tcp_store(self, + port: Optional[int] = None + ) -> torch.distributed.TCPStore: + actual_port = port if port is not None else 0 + return torch.distributed.TCPStore(host_name=self.master_address, + port=actual_port, + world_size=self.world_size, + is_master=(self.rank == 0), + wait_for_workers=False) + + def setup_tcp_store(self): + if self.rank != 0: + raise RuntimeError("Only the master worker can setup TCP store") + self.store = self._create_tcp_store() + return self.store.port + + def setup_distributed_env(self, port: int): + if self.rank != 0: + self.store = self._create_tcp_store(port) + + torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo", + store=self.store, + world_size=self.world_size, + rank=self.rank) + self.mapping = Mapping(world_size=self.world_size, + gpus_per_node=self.world_size, + tp_size=self.tp_size, + cp_size=self.cp_size, + rank=self.rank) + self.dist = TorchDist(self.mapping) + + def run_tensor_broadcast(self, root_tensor: torch.Tensor, root: int = 0): + """Test broadcasting a tensor via cp_broadcast.""" + cp_rank = self.mapping.cp_rank + if cp_rank == root: + # Root rank has the tensor to broadcast. + tensor = root_tensor.cuda() + else: + # Non-root ranks start with zeros. + tensor = torch.zeros_like(root_tensor).cuda() + + result = self.dist.cp_broadcast(tensor, root=root) + + # After broadcast, all CP ranks should have the same tensor. + expected = root_tensor.cuda() + return torch.allclose(result, expected) + + def run_object_broadcast(self, root_obj, root: int = 0): + """Test broadcasting a non-tensor object via cp_broadcast.""" + cp_rank = self.mapping.cp_rank + if cp_rank == root: + obj = root_obj + else: + obj = None + + result = self.dist.cp_broadcast(obj, root=root) + + # After broadcast, all CP ranks should have the same object. + return result == root_obj + + def run_tp_cp_broadcast(self, root_obj, root: int = 0): + """Test broadcasting an object via tp_cp_broadcast.""" + # For tp_cp_broadcast, only rank 0 in both TP and CP should have the object. + tp_rank = self.mapping.tp_rank + cp_rank = self.mapping.cp_rank + if tp_rank == root and cp_rank == root: + obj = root_obj + else: + obj = None + + result = self.dist.tp_cp_broadcast(obj, root=root) + + # After broadcast, all TP and CP ranks should have the same object. + return result == root_obj + + +@pytest.mark.gpu2 +@pytest.mark.parametrize("hidden_size", [128, 512], ids=lambda x: f"hidden:{x}") +@pytest.mark.parametrize("seq_len", [16, 32], ids=lambda x: f"seqlen:{x}") +def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size): + """Test TorchDist.cp_broadcast with tensor data.""" + torch.manual_seed(42) + dtype = torch.bfloat16 + world_size = 2 + tp_size = 1 + cp_size = 2 # Enable context parallelism. + + # Create tensor to broadcast from root. 
+ root_tensor = torch.randn((seq_len, hidden_size), dtype=dtype) + + runtime_env = ray.runtime_env.RuntimeEnv() + runtime_env["env_vars"] = os.environ.copy() + runtime_env["env_vars"].update({ + "TLLM_DISABLE_MPI": "1", + "MASTER_ADDR": "127.0.0.1", + }) + + remote_tests = [] + for rank in range(world_size): + remote_tests.append( + CpBroadcastTest.options(runtime_env=runtime_env).remote( + rank, world_size, tp_size, cp_size)) + + ray.get([test.__ray_ready__.remote() for test in remote_tests]) + + port = ray.get(remote_tests[0].setup_tcp_store.remote()) + ray.get([test.setup_distributed_env.remote(port) for test in remote_tests]) + + # Test broadcasting from root=0. + results = ray.get([ + test.run_tensor_broadcast.remote(root_tensor, root=0) + for test in remote_tests + ]) + for r in results: + assert r is True, "Tensor broadcast from root=0 failed" + + +@pytest.mark.gpu2 +@pytest.mark.parametrize("test_object", [ + { + "key1": "value1", + "key2": [1, 2, 3] + }, + ["item1", "item2", { + "nested": True + }], + "simple_string", +], + ids=["dict", "list", "string"]) +@pytest.mark.parametrize("broadcast_method", [ + "run_object_broadcast", + "run_tp_cp_broadcast", +], + ids=["cp_broadcast", "tp_cp_broadcast"]) +def test_cp_tp_broadcast_object(setup_ray_cluster, test_object, + broadcast_method): + """Test TorchDist.cp_broadcast and tp_cp_broadcast with non-tensor objects. + + This tests both cp_broadcast (for context parallelism only) and tp_cp_broadcast + (for combined TP+CP broadcast used in helix parallelism). + """ + world_size = 2 + tp_size = 1 + cp_size = 2 # Enable context parallelism. + + runtime_env = ray.runtime_env.RuntimeEnv() + runtime_env["env_vars"] = os.environ.copy() + runtime_env["env_vars"].update({ + "TLLM_DISABLE_MPI": "1", + "MASTER_ADDR": "127.0.0.1", + }) + + remote_tests = [] + for rank in range(world_size): + remote_tests.append( + CpBroadcastTest.options(runtime_env=runtime_env).remote( + rank, world_size, tp_size, cp_size)) + + ray.get([test.__ray_ready__.remote() for test in remote_tests]) + + port = ray.get(remote_tests[0].setup_tcp_store.remote()) + ray.get([test.setup_distributed_env.remote(port) for test in remote_tests]) + + # Test broadcasting object from root=0 using the specified method. + results = ray.get([ + getattr(test, broadcast_method).remote(test_object, root=0) + for test in remote_tests + ]) + for r in results: + assert r is True, f"{broadcast_method} from root=0 failed for {type(test_object)}"
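
For illustration only (not part of the patch): a minimal sketch of how the new tp_cp_broadcast API composes the two broadcasts, assuming a 4-rank mpirun launch with tp_size=2 and cp_size=2 so that both stages are exercised. It uses only the Mapping/MPIDist entry points that already appear in the tests above; the payload contents are made up for the example.

from mpi4py import MPI

from tensorrt_llm._torch.distributed import MPIDist
from tensorrt_llm.mapping import Mapping

comm = MPI.COMM_WORLD
rank, world_size = comm.Get_rank(), comm.Get_size()  # sketch assumes world_size == 4

# tp_size * cp_size * pp_size must equal world_size.
mapping = Mapping(world_size=world_size, rank=rank, tp_size=2, cp_size=2, pp_size=1)
dist = MPIDist(mapping=mapping)

# Only the rank that is root in both its TP and CP groups holds the payload.
# tp_cp_broadcast first broadcasts within the TP group, then within the CP group,
# so every rank in the combined TP/CP group ends up with the same object.
if mapping.tp_rank == 0 and mapping.cp_rank == 0:
    payload = {"tactics": [1, 2, 3]}
else:
    payload = None

payload = dist.tp_cp_broadcast(payload, root=0)
assert payload["tactics"] == [1, 2, 3]

This is the pattern the autotuner and executor_request_queue changes rely on: when TP and CP are both greater than 1 (helix parallelism), a plain tp_broadcast would leave the CP ranks without the profiling cache entries or the new-request payloads.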