import socket
import pytest

import flashinfer.comm as comm

import paddle
import paddle.distributed as dist_pp

paddle.compat.enable_torch_proxy()

import os
import numpy as np
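# Note: this test is written against Paddle's torch compatibility layer
# (enable_torch_proxy above): torch-style calls such as paddle.cuda.synchronize,
# paddle.manual_seed, and tensor creation with dtype=/device= keywords are used
# throughout, and flashinfer.comm expects a torch-like tensor API.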
# Test parameters
token_num = 128
hidden_dim = 1024
dtype = paddle.float16
pattern_code = comm.AllReduceFusionPattern.kAllReduce
layout_code = comm.QuantizationSFLayout.LINEAR
launch_with_pdl = False
use_oneshot = True
trigger_completion_at_end = True
fp32_acc = False

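# The fused kernel is exercised with the plain kAllReduce pattern, so only
# allreduce_out is verified (against paddle.distributed.all_reduce); the residual,
# norm, quant, and scale buffers created in kernel() are allocated to satisfy the
# fused API's signature and are not checked here.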
def kernel(workspace_tensor, rank, world_size):
    """Run trtllm_allreduce_fusion on this rank and return (input copy, fused output)."""
    device = f"cuda:{rank}"
    message_size = token_num * hidden_dim

    # Create input data and keep a copy for the reference all_reduce later
    allreduce_in = paddle.randn(message_size, dtype=dtype, device=device)
    allreduce_in_clone = allreduce_in.clone()
    all_reduce_out = paddle.zeros(message_size, dtype=dtype, device=device)

    # Remaining buffers required by the fused kernel's signature
    residual_in = paddle.randn(message_size, dtype=dtype, device=device)
    residual_out = paddle.zeros(message_size, dtype=dtype, device=device)
    norm_out = paddle.zeros(message_size, dtype=dtype, device=device)
    quant_out = paddle.zeros(message_size, dtype=dtype, device=device)
    scale_out = paddle.zeros(message_size // 16, dtype=dtype, device=device)  # SF_VEC_SIZE = 16
    rms_gamma = paddle.randn(hidden_dim, dtype=dtype, device=device)
    rms_eps = 1e-3
    scale_factor = paddle.tensor(0.5, dtype=paddle.float32, device=device)

    # Run fusion operation
    print("Running fusion operation...")
    comm.trtllm_allreduce_fusion(
        allreduce_in=allreduce_in,
        world_size=world_size,
        world_rank=rank,
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=workspace_tensor,
        launch_with_pdl=launch_with_pdl,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=pattern_code,
        allreduce_out=all_reduce_out,
        residual_in=residual_in,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=quant_out,
        scale_out=scale_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        scale_factor=scale_factor,
        layout_code=layout_code,
    )

    paddle.cuda.synchronize()

    return allreduce_in_clone, all_reduce_out


def _run_simple_worker(world_size, rank, distributed_init_port):
    # Set the environment variables required by paddle.distributed
    os.environ['FLAGS_selected_gpus'] = str(rank)  # select the GPU for this rank
    os.environ['PADDLE_TRAINER_ID'] = str(rank)
    os.environ['PADDLE_TRAINERS_NUM'] = str(world_size)
    os.environ['PADDLE_RANK_IN_NODE'] = str(rank)

    # Build the endpoint list (one endpoint per rank)
    endpoints = ','.join(
        [f'127.0.0.1:{distributed_init_port + i + 10}' for i in range(world_size)]
    )
    os.environ['PADDLE_TRAINER_ENDPOINTS'] = endpoints
    os.environ['PADDLE_CURRENT_ENDPOINT'] = f'127.0.0.1:{distributed_init_port + rank + 10}'
    # NCCL-related environment variables (optional but recommended)
    os.environ['FLAGS_sync_nccl_allreduce'] = '1'

    # Set device
    paddle.set_device(f"gpu:{rank}")

    # Initialize the distributed environment
    dist_pp.init_parallel_env()
    group_pp = dist_pp.get_group()

    try:
        # Create workspace
        ipc_handles, workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                rank,
                world_size,
                token_num,
                hidden_dim,
                group=group_pp,
                use_fp32_lamport=False,
            )
        )

        dist_pp.barrier(group=group_pp)

        # Run fusion operation
        allreduce_in_clone, all_reduce_out = kernel(workspace_tensor, rank, world_size)

        # Calculate reference result
        dist_pp.all_reduce(allreduce_in_clone, group=group_pp)
        ref_allreduce_out = allreduce_in_clone.clone()

        # Verify results
        tolerance = 8e-2
        np.testing.assert_allclose(
            all_reduce_out.numpy(),
            ref_allreduce_out.numpy(),
            atol=tolerance,
            rtol=1e-2,
        )

        print(f"Rank {rank}: Test passed!")

    finally:
        dist_pp.barrier(group=group_pp)
        comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group_pp)
        dist_pp.destroy_process_group(group=group_pp)


def get_open_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def test_trtllm_allreduce_fusion_simple():
    # Fixed test parameters
    world_size = 2

    paddle.manual_seed(42)
    paddle.cuda.manual_seed_all(42)

    available_gpus = paddle.cuda.device_count()
    if world_size > available_gpus:
        pytest.skip(f"Requires {world_size} GPUs, but only {available_gpus} available")

    distributed_init_port = get_open_port()
    rank = dist_pp.get_rank()
    _run_simple_worker(world_size, rank, distributed_init_port)

    print("Simple allreduce fusion test: passed")


# test cmd:
#   python -m paddle.distributed.launch --log_dir=log --devices=0,1 ./test_torch_pp_launch.py
if __name__ == "__main__":
    test_trtllm_allreduce_fusion_simple()