2929except ImportError :
3030 pass
3131
32- # Global workspace for standalone allreduce and non-quant ar+rms fusion
32+ # Workspace for standalone allreduce and non-quant ar+rms fusion
3333_fi_ar_workspace = None
3434# Extra workspace for quant fusion patterns (only supported by trtllm backend)
35- # Only created if primary workspace is not already trtllm
3635_fi_ar_quant_workspace = None
3736
3837
39- def get_fi_ar_workspace ():
40- return _fi_ar_workspace
41-
42-
43- def get_fi_ar_quant_workspace ():
44- return _fi_ar_quant_workspace
45-
46-
47- def initialize_fi_ar_workspace (
def _create_workspace(
    backend: str,
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
):
    """Create a flashinfer allreduce workspace, returning None on failure.

    The global ``random`` state is reseeded from ``os.urandom`` for the
    duration of the create call and restored afterwards, so workspace
    creation never perturbs the caller's RNG sequence.
    """
    dist_backend = TorchDistBackend(group=group)
    saved_rng_state = random.getstate()
    try:
        random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
        workspace = flashinfer_comm.create_allreduce_fusion_workspace(
            backend=backend,
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            dtype=dtype,
            comm_backend=dist_backend,
        )
    except Exception as exc:
        # A "multicast" failure is a known, benign topology limitation;
        # anything else gets the generic warning.
        if "multicast" in str(exc).lower():
            message = (
                "Failed to initialize FlashInfer All Reduce workspace: %s. "
                "This is expected on GPUs without NVSwitch (e.g., NVLink "
                "bridge-only or PCIe topologies)."
            )
        else:
            message = "Failed to initialize FlashInfer All Reduce workspace: %s."
        logger.warning_once(message, exc)
        return None
    finally:
        random.setstate(saved_rng_state)
    logger.debug(
        "Initialized FlashInfer All Reduce workspace: backend=%s, "
        "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
        backend,
        world_size,
        rank,
        max_token_num,
        hidden_dim,
        dtype,
    )
    return workspace
88+
89+
def get_fi_ar_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
):
    """
    Return the allreduce workspace for non-quant patterns, initializing if needed.

    Used by AllReduceFusionPass (non-quant patterns) and FlashInferAllReduce
    for standalone allreduce. The backend is selected via the
    VLLM_FLASHINFER_ALLREDUCE_BACKEND env var. Returns None if workspace
    creation failed.
    """
    global _fi_ar_workspace

    if _fi_ar_workspace is None:
        requested_backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
        quant_ws = _fi_ar_quant_workspace
        if quant_ws is not None and quant_ws.backend == requested_backend:
            # The quant workspace already uses the requested backend; alias it
            # instead of allocating a second one.
            _fi_ar_workspace = quant_ws
        else:
            _fi_ar_workspace = _create_workspace(
                requested_backend,
                world_size,
                rank,
                max_token_num,
                hidden_dim,
                dtype,
                group,
            )
    return _fi_ar_workspace
97120
98121
def get_fi_ar_quant_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
):
    """
    Return the allreduce workspace for quant patterns, initializing if needed.

    Always uses the trtllm backend as it is the only one supporting
    quantization fusion (FP8/FP4). Returns None if workspace creation failed.
    """
    global _fi_ar_quant_workspace

    if _fi_ar_quant_workspace is None:
        primary_ws = _fi_ar_workspace
        if primary_ws is not None and primary_ws.backend == "trtllm":
            # The non-quant workspace is already trtllm; share it rather than
            # allocating a second one.
            _fi_ar_quant_workspace = primary_ws
        else:
            _fi_ar_quant_workspace = _create_workspace(
                "trtllm",
                world_size,
                rank,
                max_token_num,
                hidden_dim,
                dtype,
                group,
            )
    return _fi_ar_quant_workspace
143149
144150
145151_fi_ar_workspace_lock = threading .Lock ()
146152
147153
def destroy_fi_ar_workspace():
    """Destroy the flashinfer allreduce workspaces (also registered via atexit).

    The quant workspace may alias the non-quant one (see the get_* helpers),
    in which case destroy() is called only once. The globals are cleared
    up-front and the second destroy runs in a ``finally`` block, so a failing
    ``destroy()`` can neither leave stale handles behind nor prevent the
    other workspace from being torn down.
    """
    global _fi_ar_workspace, _fi_ar_quant_workspace
    with _fi_ar_workspace_lock:
        primary, quant = _fi_ar_workspace, _fi_ar_quant_workspace
        # Clear globals first: even if destroy() raises below, no caller can
        # pick up a half-destroyed workspace afterwards.
        _fi_ar_workspace = _fi_ar_quant_workspace = None
        try:
            if primary is not None:
                primary.destroy()
        finally:
            # Destroy the quant workspace only when it is a distinct object.
            if quant is not None and quant is not primary:
                quant.destroy()
161165
162166
163167atexit .register (destroy_fi_ar_workspace )
@@ -209,29 +213,21 @@ def __init__(
209213
210214 def _ensure_workspace (self , hidden_dim : int , dtype : torch .dtype ) -> bool :
211215 """Ensure the all reduce workspace is initialized."""
212- if get_fi_ar_workspace () is not None :
213- return True
214216 if self .max_num_tokens == 0 :
215217 element_size = torch .tensor ([], dtype = dtype , device = "cpu" ).element_size ()
216218 self .max_num_tokens = self .max_workspace_size // (hidden_dim * element_size )
217- try :
218- initialize_fi_ar_workspace (
219- world_size = self .world_size ,
220- rank = self .rank ,
221- max_token_num = self .max_num_tokens ,
222- hidden_dim = hidden_dim ,
223- dtype = dtype ,
224- group = self .group ,
225- )
226- return True
227- except Exception as e :
228- logger .warning (
229- "Failed to initialize FlashInfer All Reduce workspace: %s. "
230- "FlashInfer All Reduce will be disabled." ,
231- e ,
232- )
219+ workspace = get_fi_ar_workspace (
220+ world_size = self .world_size ,
221+ rank = self .rank ,
222+ max_token_num = self .max_num_tokens ,
223+ hidden_dim = hidden_dim ,
224+ dtype = dtype ,
225+ group = self .group ,
226+ )
227+ if workspace is None :
233228 self .disabled = True
234229 return False
230+ return True
235231
236232 def should_use_fi_ar (self , input_tensor : torch .Tensor ) -> bool :
237233 if self .disabled :
@@ -257,7 +253,15 @@ def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
257253 return self ._ensure_workspace (hidden_dim , input_tensor .dtype )
258254
259255 def all_reduce (self , input_tensor : torch .Tensor ) -> torch .Tensor :
260- workspace = get_fi_ar_workspace ()
256+ _ , hidden_dim = input_tensor .shape
257+ workspace = get_fi_ar_workspace (
258+ world_size = self .world_size ,
259+ rank = self .rank ,
260+ max_token_num = self .max_num_tokens ,
261+ hidden_dim = hidden_dim ,
262+ dtype = input_tensor .dtype ,
263+ group = self .group ,
264+ )
261265 return flashinfer_comm .allreduce_fusion (
262266 input = input_tensor ,
263267 workspace = workspace ,
0 commit comments