Commit 3f79297

support all reduce fusion kernel
1 parent 5f2f442 commit 3f79297

File tree

4 files changed: +182 -8 lines changed


flashinfer/comm/cuda_ipc.py

Lines changed: 11 additions & 5 deletions
@@ -17,9 +17,10 @@
 import ctypes
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
-
+import paddle
+paddle.compat.enable_torch_proxy()
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from paddle.base.core import ProcessGroup
 
 # NOTE(Zihao): we should use cuda-python instead of ctypes cuda runtime bindings.
 # However, cuda-python's API is not stable yet, so we use ctypes bindings instead.
@@ -207,9 +208,14 @@ def create_shared_buffer(
     group = dist.group.WORLD
     world_size = dist.get_world_size(group=group)
     rank = dist.get_rank(group=group)
-    handles = [None] * world_size
-    dist.all_gather_object(handles, handle, group=group)
-    handles = [None] * world_size
+    # handles = [None] * world_size
+    # dist.all_gather_object(handles, handle, group=group)
+    # handles = [None] * world_size
+    # dist.all_gather_object(handles, handle, group=group)
+
+    # The behavior of the paddle framework and torch framework is inconsistent,
+    # so the following code is used instead
+    handles = []
     dist.all_gather_object(handles, handle, group=group)
 
     pointers: List[int] = []
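
For context on the comment in the hunk above: stock torch.distributed.all_gather_object fills a pre-sized output list, while the implementation that presumably runs here once the paddle torch proxy is enabled appends the gathered objects to the list it is given, as the commit's workaround implies. A minimal sketch of the pattern adopted; the appending behaviour is stated as an assumption, not verified against a specific Paddle release.

import torch.distributed as dist  # per the imports added above, this resolves through paddle's proxy

# Sketch of the gather pattern used in this commit. It assumes the proxied
# all_gather_object appends one object per rank to the list it is given,
# rather than overwriting slots of a pre-sized list the way stock torch does.
def gather_ipc_handles(handle, world_size, group):
    handles = []  # not [None] * world_size: appended results would land after the placeholders
    dist.all_gather_object(handles, handle, group=group)
    assert len(handles) == world_size
    return handles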

flashinfer/comm/nvshmem_allreduce.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from typing import Optional
 
 import torch
-from torch.distributed import ProcessGroup
+from paddle.base.core import ProcessGroup
 
 from .nvshmem import get_nvshmem_module
 
flashinfer/comm/trtllm_ar.py

Lines changed: 10 additions & 2 deletions
@@ -20,9 +20,11 @@
 from types import SimpleNamespace
 from typing import List, Optional, Tuple, Union
 
+import paddle
+paddle.compat.enable_torch_proxy()
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from paddle.base.core import ProcessGroup
 
 from ..jit.comm import gen_trtllm_comm_module
 from ..utils import register_custom_op, round_up
@@ -602,8 +604,14 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
         print(f"Rank {tp_rank} workspace[{i}] {hex(workspace[i])}")
 
     # Store workspace pointers in device tensor
+    # workspace_tensor = torch.tensor(
+    #     workspace, dtype=torch.int64, device=torch.device("cuda")
+    # )
+
+    # There is a bug in the paddle framework when device="CUDA".
+    # Currently, the bug is being avoided by changing the source code.
     workspace_tensor = torch.tensor(
-        workspace, dtype=torch.int64, device=torch.device("cuda")
+        workspace, dtype=torch.int64
     )
 
     dist.barrier(group=group)  # must sync after create_workspace
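
For context on the workaround above: the hex(workspace[i]) print shows that workspace is a list of integer pointer values, so the tensor being built is simply a pointer table, now created on the default device because the device argument is dropped. If the paddle device-string issue is fixed, the explicit GPU placement could be restored with a separate move; a hedged sketch (the .cuda() hop is an assumption about the proxied tensor API, not something this commit does):

# Sketch only, not part of this commit: build the int64 pointer table first,
# then move it explicitly once paddle accepts the device argument again.
workspace_tensor = torch.tensor(workspace, dtype=torch.int64)
workspace_tensor = workspace_tensor.cuda()  # assumed to behave under the paddle proxy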
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import socket
import pytest

import flashinfer.comm as comm

import paddle
import paddle.distributed as dist_pp
paddle.compat.enable_torch_proxy()

import os
import numpy as np

# test parameters
token_num = 128
hidden_dim = 1024
dtype = paddle.float16
pattern_code = comm.AllReduceFusionPattern.kAllReduce
layout_code = comm.QuantizationSFLayout.LINEAR
launch_with_pdl = False
use_oneshot = True
trigger_completion_at_end = True
fp32_acc = False

def kernel(workspace_tensor, rank, world_size):
    device = f"cuda:{rank}"
    message_size = token_num * hidden_dim
    dtype = paddle.float16
    # Create input data
    allreduce_in = paddle.randn(message_size, dtype=dtype, device=device)
    # allreduce_in_clone = allreduce_in.clone()
    all_reduce_out = paddle.zeros(message_size, dtype=dtype, device=device)

    # Add missing required parameters
    residual_in = paddle.randn(message_size, dtype=dtype, device=device)
    residual_out = paddle.zeros(message_size, dtype=dtype, device=device)
    norm_out = paddle.zeros(message_size, dtype=dtype, device=device)
    quant_out = paddle.zeros(message_size, dtype=dtype, device=device)
    scale_out = paddle.zeros(message_size // 16, dtype=dtype, device=device)  # SF_VEC_SIZE = 16
    rms_gamma = paddle.randn(hidden_dim, dtype=dtype, device=device)
    rms_eps = 1e-3
    scale_factor = paddle.tensor(0.5, dtype=paddle.float32, device=device)

    # Run fusion operation
    print("Running fusion operation...")
    comm.trtllm_allreduce_fusion(
        allreduce_in=allreduce_in,
        world_size=world_size,
        world_rank=rank,
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=workspace_tensor,
        launch_with_pdl=launch_with_pdl,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=pattern_code,
        allreduce_out=all_reduce_out,
        residual_in=residual_in,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=quant_out,
        scale_out=scale_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        scale_factor=scale_factor,
        layout_code=layout_code,
    )

    paddle.cuda.synchronize()

    return allreduce_in, all_reduce_out

def _run_simple_worker(world_size, rank, distributed_init_port):

    # Create workspace
    # paddle.compat.enable_torch_proxy()
    # Set all required environment variables
    os.environ['FLAGS_selected_gpus'] = str(rank)  # Key: set GPU ID
    os.environ['PADDLE_TRAINER_ID'] = str(rank)
    os.environ['PADDLE_TRAINERS_NUM'] = str(world_size)
    os.environ['PADDLE_RANK_IN_NODE'] = str(rank)

    # Build endpoint list
    endpoints = ','.join([f'127.0.0.1:{distributed_init_port+i+10}' for i in range(world_size)])
    os.environ['PADDLE_TRAINER_ENDPOINTS'] = endpoints
    os.environ['PADDLE_CURRENT_ENDPOINT'] = f'127.0.0.1:{distributed_init_port+rank+10}'
    # Set NCCL related environment variables (optional but recommended)
    os.environ['FLAGS_sync_nccl_allreduce'] = '1'

    # Set device
    paddle.set_device(f"gpu:{rank}")

    # Initialize distributed environment
    dist_pp.init_parallel_env()
    group_pp = dist_pp.get_group()

    try:
        # Create workspace
        ipc_handles, workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                rank,
                world_size,
                token_num,
                hidden_dim,
                group=group_pp,
                use_fp32_lamport=False,
            )
        )

        dist_pp.barrier(group=group_pp)

        # Run fusion operation
        allreduce_in_clone, all_reduce_out = kernel(workspace_tensor, rank, world_size)

        # Calculate reference result
        dist_pp.all_reduce(allreduce_in_clone, group=group_pp)
        ref_allreduce_out = allreduce_in_clone.clone()

        # Verify results
        tolerance = 8e-2
        np.testing.assert_allclose(all_reduce_out.numpy(),
                                   ref_allreduce_out.numpy(), atol=tolerance, rtol=1e-2)

        print(f"Rank {rank}: Test passed!")

    finally:
        dist_pp.barrier(group=group_pp)
        comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group_pp)
        dist_pp.destroy_process_group(group=group_pp)


def get_open_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def test_trtllm_allreduce_fusion_simple():
    # Fixed test parameters
    world_size = 2

    paddle.manual_seed(42)
    paddle.cuda.manual_seed_all(42)

    available_gpus = paddle.cuda.device_count()
    if world_size > available_gpus:
        pytest.skip(f"Requires {world_size} GPUs, but only {available_gpus} available")

    procs = []
    distributed_init_port = get_open_port()
    rank = dist_pp.get_rank()
    _run_simple_worker(world_size, rank, distributed_init_port)

    print("Simple allreduce fusion test: passed")


# test cmd: python -m paddle.distributed.launch --log_dir=log --devices=0,1
#           ./test_torch_pp_launch.py
if __name__ == "__main__":
    test_trtllm_allreduce_fusion_simple()
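
A note on how this test is meant to run: per the trailing comment, paddle.distributed.launch starts one process per listed device, so each process executes the test and _run_simple_worker only handles its own rank, with the environment variables set inside the worker mirroring what the launcher provides. A hypothetical, self-contained alternative would spawn the workers from the test itself; in the sketch below, _spawn_workers, the spawn start method, and the reuse of _run_simple_worker are illustrative assumptions, not part of this commit.

import multiprocessing as mp

# Hypothetical alternative launcher (sketch only): spawn one worker process per
# rank directly from Python instead of using `python -m paddle.distributed.launch`.
def _spawn_workers(world_size: int, port: int) -> None:
    ctx = mp.get_context("spawn")  # CUDA state cannot be forked safely
    procs = [
        ctx.Process(target=_run_simple_worker, args=(world_size, rank, port))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
        assert p.exitcode == 0, f"worker failed with exit code {p.exitcode}"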
