Skip to content

Commit 0140eaf

Browse files
wzhao18, root, and root
authored
[Bug] Fix FlashInfer allreduce fusion workspace uninitialized error (vllm-project#37461)
Signed-off-by: root <root@prenyx0169.a51.clusters.nvidia.com> Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Signed-off-by: <> Co-authored-by: root <root@prenyx0169.a51.clusters.nvidia.com> Co-authored-by: root <root@prenyx0042.a51.clusters.nvidia.com>
1 parent bdf6a0a commit 0140eaf

File tree

2 files changed

+128
-125
lines changed

2 files changed

+128
-125
lines changed

vllm/compilation/passes/fusion/allreduce_rms_fusion.py

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@
8686
destroy_fi_ar_workspace,
8787
get_fi_ar_quant_workspace,
8888
get_fi_ar_workspace,
89-
initialize_fi_ar_quant_workspace,
90-
initialize_fi_ar_workspace,
9189
)
9290

9391
ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern
@@ -133,15 +131,23 @@ def call_trtllm_fused_allreduce_norm(
133131

134132
# Select workspace based on pattern: quant patterns use the
135133
# trtllm quant workspace, non-quant patterns use the primary workspace.
136-
if pattern_code in (
134+
is_quant_pattern = pattern_code in (
137135
ar_fusion_patterns.kARResidualRMSNormFP8Quant,
138136
ar_fusion_patterns.kARResidualRMSNormFP4Quant,
139-
):
140-
workspace = get_fi_ar_quant_workspace()
141-
else:
142-
workspace = get_fi_ar_workspace()
137+
)
138+
get_workspace_fn = (
139+
get_fi_ar_quant_workspace if is_quant_pattern else get_fi_ar_workspace
140+
)
141+
workspace = get_workspace_fn(
142+
world_size=world_size,
143+
rank=get_tensor_model_parallel_rank(),
144+
max_token_num=max_token_num,
145+
hidden_dim=hidden_size,
146+
dtype=allreduce_in.dtype,
147+
group=get_tp_group().device_group,
148+
)
143149
assert workspace is not None, (
144-
"Flashinfer workspace must be initialized when using flashinfer"
150+
"Flashinfer allreduce workspace must be initialized when using flashinfer"
145151
)
146152
assert flashinfer_comm is not None
147153
if norm_out is None:
@@ -753,35 +759,29 @@ def __init__(self, config: VllmConfig) -> None:
753759
scope="global",
754760
)
755761

756-
for workspace_init_fn in [
757-
initialize_fi_ar_workspace,
758-
initialize_fi_ar_quant_workspace,
759-
]:
760-
try:
761-
workspace_init_fn(
762-
world_size=self.tp_size,
763-
rank=rank,
764-
max_token_num=self.max_token_num,
765-
hidden_dim=self.hidden_dim,
766-
dtype=self.model_dtype,
767-
group=self.group,
768-
)
769-
except Exception as e:
770-
if "multicast" in str(e).lower():
771-
logger.warning(
772-
"AllReduce fusion pass is disabled: flashinfer workspace "
773-
"creation failed: %s. This is expected on GPUs without "
774-
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
775-
"Falling back to non-fused allreduce.",
776-
str(e),
777-
)
778-
else:
779-
logger.warning(
780-
"Failed to initialize FlashInfer All Reduce workspace: %s. "
781-
"AllReduce fusion pass will be disabled.",
782-
e,
783-
)
784-
return
762+
workspace_kwargs = dict(
763+
world_size=self.tp_size,
764+
rank=rank,
765+
max_token_num=self.max_token_num,
766+
hidden_dim=self.hidden_dim,
767+
dtype=self.model_dtype,
768+
group=self.group,
769+
)
770+
if get_fi_ar_workspace(**workspace_kwargs) is None:
771+
logger.warning_once(
772+
"Failed to initialize Flashinfer allreduce workspace. "
773+
"Flashinfer allreduce-norm fusion will be disabled."
774+
)
775+
return
776+
777+
self.supports_quant_fusion = (
778+
get_fi_ar_quant_workspace(**workspace_kwargs) is not None
779+
)
780+
if not self.supports_quant_fusion:
781+
logger.warning_once(
782+
"Failed to initialize Flashinfer allreduce workspace. "
783+
"Flashinfer allreduce-norm-quant fusion will be disabled."
784+
)
785785

786786
self.allreduce_params = FlashInferFusedAllReduceParams(
787787
world_size=self.tp_size,
@@ -793,9 +793,8 @@ def __init__(self, config: VllmConfig) -> None:
793793

794794
@enable_fake_mode
795795
def register_patterns(self) -> None:
796-
supports_quantization = get_fi_ar_quant_workspace() is not None
797796
for epsilon in [1e-5, 1e-6]:
798-
if supports_quantization:
797+
if self.supports_quant_fusion:
799798
AllReduceFusedRMSNormStaticQuantFP8Pattern(
800799
epsilon,
801800
self.model_dtype,

vllm/distributed/device_communicators/flashinfer_all_reduce.py

Lines changed: 90 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -29,50 +29,27 @@
2929
except ImportError:
3030
pass
3131

32-
# Global workspace for standalone allreduce and non-quant ar+rms fusion
32+
# Workspace for standalone allreduce and non-quant ar+rms fusion
3333
_fi_ar_workspace = None
3434
# Extra workspace for quant fusion patterns (only supported by trtllm backend)
35-
# Only created if primary workspace is not already trtllm
3635
_fi_ar_quant_workspace = None
3736

3837

39-
def get_fi_ar_workspace():
40-
return _fi_ar_workspace
41-
42-
43-
def get_fi_ar_quant_workspace():
44-
return _fi_ar_quant_workspace
45-
46-
47-
def initialize_fi_ar_workspace(
38+
def _create_workspace(
39+
backend: str,
4840
world_size: int,
4941
rank: int,
5042
max_token_num: int,
5143
hidden_dim: int,
5244
dtype: torch.dtype,
5345
group: ProcessGroup,
54-
) -> None:
55-
"""
56-
Initialize the workspace if not already initialized.
57-
58-
Currently, this function is called by either the AllReduceFusionPass
59-
or the FlashInferAllReduce backend for standalone allreduce.
60-
If the fusion pass is enabled via
61-
--compilation-config.pass_config.fuse_allreduce_rms=true,
62-
it will create the workspace first, and the standalone backend
63-
will reuse the workspace. Otherwise, the standalone backend will
64-
create the workspace.
65-
"""
66-
global _fi_ar_workspace
67-
if _fi_ar_workspace is not None:
68-
return
69-
70-
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
46+
):
47+
"""Create a flashinfer allreduce workspace, returning None on failure."""
7148
comm_backend = TorchDistBackend(group=group)
7249
rng_state = random.getstate()
7350
try:
7451
random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
75-
_fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
52+
workspace = flashinfer_comm.create_allreduce_fusion_workspace(
7653
backend=backend,
7754
world_size=world_size,
7855
rank=rank,
@@ -81,9 +58,22 @@ def initialize_fi_ar_workspace(
8158
dtype=dtype,
8259
comm_backend=comm_backend,
8360
)
61+
except Exception as e:
62+
if "multicast" in str(e).lower():
63+
logger.warning_once(
64+
"Failed to initialize FlashInfer All Reduce workspace: %s. "
65+
"This is expected on GPUs without NVSwitch (e.g., NVLink "
66+
"bridge-only or PCIe topologies).",
67+
e,
68+
)
69+
else:
70+
logger.warning_once(
71+
"Failed to initialize FlashInfer All Reduce workspace: %s.",
72+
e,
73+
)
74+
return None
8475
finally:
8576
random.setstate(rng_state)
86-
assert _fi_ar_workspace is not None
8777
logger.debug(
8878
"Initialized FlashInfer All Reduce workspace: backend=%s, "
8979
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
@@ -94,70 +84,84 @@ def initialize_fi_ar_workspace(
9484
hidden_dim,
9585
dtype,
9686
)
87+
return workspace
88+
89+
90+
def get_fi_ar_workspace(
91+
world_size: int,
92+
rank: int,
93+
max_token_num: int,
94+
hidden_dim: int,
95+
dtype: torch.dtype,
96+
group: ProcessGroup,
97+
):
98+
"""
99+
Return the allreduce workspace for non-quant patterns, initializing if needed.
100+
101+
Used by AllReduceFusionPass (non-quant patterns) and FlashInferAllReduce
102+
for standalone allreduce. Backend is controlled by
103+
VLLM_FLASHINFER_ALLREDUCE_BACKEND env var.
104+
"""
105+
global _fi_ar_workspace
106+
if _fi_ar_workspace is not None:
107+
return _fi_ar_workspace
108+
109+
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
110+
111+
# Reuse the quant workspace if it was already created with the same backend
112+
if _fi_ar_quant_workspace is not None and _fi_ar_quant_workspace.backend == backend:
113+
_fi_ar_workspace = _fi_ar_quant_workspace
114+
return _fi_ar_workspace
115+
116+
_fi_ar_workspace = _create_workspace(
117+
backend, world_size, rank, max_token_num, hidden_dim, dtype, group
118+
)
119+
return _fi_ar_workspace
97120

98121

99-
def initialize_fi_ar_quant_workspace(
122+
def get_fi_ar_quant_workspace(
100123
world_size: int,
101124
rank: int,
102125
max_token_num: int,
103126
hidden_dim: int,
104127
dtype: torch.dtype,
105128
group: ProcessGroup,
106-
) -> None:
129+
):
107130
"""
108-
Initialize the workspace used by quantization fusion patterns.
131+
Return the allreduce workspace for quant patterns, initializing if needed.
109132
110-
Currently this always creates a workspace for trtllm backend as only it
111-
supports quantization fusion (FP8/FP4). If the primary workspace
112-
is already trtllm, the quant workspace aliases to it.
133+
Always uses trtllm backend as it is the only one supporting quantization
134+
fusion (FP8/FP4).
113135
"""
114136
global _fi_ar_quant_workspace
115137
if _fi_ar_quant_workspace is not None:
116-
return
138+
return _fi_ar_quant_workspace
117139

118-
# If primary workspace is already trtllm, reuse it
140+
# Reuse the non-quant workspace if it was already created with trtllm
119141
if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
120142
_fi_ar_quant_workspace = _fi_ar_workspace
121-
return
143+
return _fi_ar_quant_workspace
122144

123-
comm_backend = TorchDistBackend(group=group)
124-
_fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
125-
backend="trtllm",
126-
world_size=world_size,
127-
rank=rank,
128-
max_token_num=max_token_num,
129-
hidden_dim=hidden_dim,
130-
dtype=dtype,
131-
comm_backend=comm_backend,
132-
)
133-
assert _fi_ar_quant_workspace is not None
134-
logger.debug(
135-
"Initialized FlashInfer All Reduce workspace: backend=trtllm, "
136-
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
137-
world_size,
138-
rank,
139-
max_token_num,
140-
hidden_dim,
141-
dtype,
145+
_fi_ar_quant_workspace = _create_workspace(
146+
"trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group
142147
)
148+
return _fi_ar_quant_workspace
143149

144150

145151
_fi_ar_workspace_lock = threading.Lock()
146152

147153

148154
def destroy_fi_ar_workspace():
149-
global _fi_ar_workspace
150-
global _fi_ar_quant_workspace
155+
global _fi_ar_workspace, _fi_ar_quant_workspace
151156
with _fi_ar_workspace_lock:
152-
if (
153-
_fi_ar_quant_workspace is not None
154-
and _fi_ar_quant_workspace is not _fi_ar_workspace
155-
):
156-
_fi_ar_quant_workspace.destroy()
157-
_fi_ar_quant_workspace = None
157+
is_alias = _fi_ar_workspace is _fi_ar_quant_workspace
158+
158159
if _fi_ar_workspace is not None:
159160
_fi_ar_workspace.destroy()
160-
_fi_ar_workspace = None
161+
if _fi_ar_quant_workspace is not None and not is_alias:
162+
_fi_ar_quant_workspace.destroy()
163+
164+
_fi_ar_workspace = _fi_ar_quant_workspace = None
161165

162166

163167
atexit.register(destroy_fi_ar_workspace)
@@ -209,29 +213,21 @@ def __init__(
209213

210214
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
211215
"""Ensure the all reduce workspace is initialized."""
212-
if get_fi_ar_workspace() is not None:
213-
return True
214216
if self.max_num_tokens == 0:
215217
element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
216218
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
217-
try:
218-
initialize_fi_ar_workspace(
219-
world_size=self.world_size,
220-
rank=self.rank,
221-
max_token_num=self.max_num_tokens,
222-
hidden_dim=hidden_dim,
223-
dtype=dtype,
224-
group=self.group,
225-
)
226-
return True
227-
except Exception as e:
228-
logger.warning(
229-
"Failed to initialize FlashInfer All Reduce workspace: %s. "
230-
"FlashInfer All Reduce will be disabled.",
231-
e,
232-
)
219+
workspace = get_fi_ar_workspace(
220+
world_size=self.world_size,
221+
rank=self.rank,
222+
max_token_num=self.max_num_tokens,
223+
hidden_dim=hidden_dim,
224+
dtype=dtype,
225+
group=self.group,
226+
)
227+
if workspace is None:
233228
self.disabled = True
234229
return False
230+
return True
235231

236232
def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
237233
if self.disabled:
@@ -257,7 +253,15 @@ def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
257253
return self._ensure_workspace(hidden_dim, input_tensor.dtype)
258254

259255
def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
260-
workspace = get_fi_ar_workspace()
256+
_, hidden_dim = input_tensor.shape
257+
workspace = get_fi_ar_workspace(
258+
world_size=self.world_size,
259+
rank=self.rank,
260+
max_token_num=self.max_num_tokens,
261+
hidden_dim=hidden_dim,
262+
dtype=input_tensor.dtype,
263+
group=self.group,
264+
)
261265
return flashinfer_comm.allreduce_fusion(
262266
input=input_tensor,
263267
workspace=workspace,

0 commit comments

Comments (0)