 
 import torch
 import torch.distributed as dist
-import torch.distributed._symmetric_memory as symm_mem
 import torch.multiprocessing as mp
 import triton
-from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_dict_random, make_expt_assignment
+from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_dict_random, make_expt_assignment, symm_mem_pool
 from triton_kernels.reduce import reduce
 from triton_kernels.topk import topk
 from triton_kernels.matmul_ogs import matmul_ogs, RoutingData, GatherIndx, ScatterIndx
@@ -166,99 +165,6 @@ def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, ex
     return z_dp_local
 
 
-def _capture_with_prepared_symm_mem(fn):
-    """
-    Run `fn` once to record symmetric-memory allocations, preallocate them outside the CUDA graph,
-    and capture a CUDA graph that reuses the recorded buffers.
-    """
-    orig_symm_empty = symm_mem.empty
-    orig_symm_rendezvous = symm_mem.rendezvous
-    recorded_empty_calls = []
-    recorded_rendezvous_calls = []
-    buffer_id_to_index = {}
-
-    def recording_empty(*args, **kwargs):
-        buf = orig_symm_empty(*args, **kwargs)
-        idx = len(recorded_empty_calls)
-        buffer_id_to_index[id(buf)] = idx
-        recorded_empty_calls.append((args, dict(kwargs)))
-        return buf
-
-    def recording_rendezvous(buf, *args, **kwargs):
-        buf_id = id(buf)
-        if buf_id not in buffer_id_to_index:
-            raise RuntimeError("symm_mem.rendezvous called on unknown buffer")
-        hdl = orig_symm_rendezvous(buf, *args, **kwargs)
-        recorded_rendezvous_calls.append((buffer_id_to_index[buf_id], args, dict(kwargs)))
-        return hdl
-
-    symm_mem.empty = recording_empty
-    symm_mem.rendezvous = recording_rendezvous
-    try:
-        warmup_result = fn()
-    finally:
-        symm_mem.empty = orig_symm_empty
-        symm_mem.rendezvous = orig_symm_rendezvous
-
-    prepared_empty_buffers = [orig_symm_empty(*args, **kwargs) for args, kwargs in recorded_empty_calls]
-    prepared_handles = [
-        orig_symm_rendezvous(prepared_empty_buffers[idx], *args, **kwargs)
-        for idx, args, kwargs in recorded_rendezvous_calls
-    ]
-
-    capture_stream = torch.cuda.Stream()
-    graph = torch.cuda.CUDAGraph()
-
-    if recorded_empty_calls:
-        empty_idx = 0
-        rendezvous_idx = 0
-
-        def reuse_empty(*args, **kwargs):
-            nonlocal empty_idx
-            if empty_idx >= len(prepared_empty_buffers):
-                raise RuntimeError("symm_mem.empty called more times than recorded")
-            expected_args, expected_kwargs = recorded_empty_calls[empty_idx]
-            if expected_args != args or expected_kwargs != kwargs:
-                raise RuntimeError("symm_mem.empty called with unexpected arguments")
-            buf = prepared_empty_buffers[empty_idx]
-            empty_idx += 1
-            return buf
-
-        def reuse_rendezvous(buf, *args, **kwargs):
-            nonlocal rendezvous_idx
-            if rendezvous_idx >= len(prepared_handles):
-                raise RuntimeError("symm_mem.rendezvous called more times than recorded")
-            expected_empty_idx, expected_args, expected_kwargs = recorded_rendezvous_calls[rendezvous_idx]
-            expected_buf = prepared_empty_buffers[expected_empty_idx]
-            if buf is not expected_buf:
-                raise RuntimeError("symm_mem.rendezvous received unexpected buffer")
-            if expected_args != args or expected_kwargs != kwargs:
-                raise RuntimeError("symm_mem.rendezvous called with unexpected arguments")
-            handle = prepared_handles[rendezvous_idx]
-            rendezvous_idx += 1
-            return handle
-
-        symm_mem.empty = reuse_empty
-        symm_mem.rendezvous = reuse_rendezvous
-        try:
-            with torch.cuda.stream(capture_stream):
-                with torch.cuda.graph(graph):
-                    fn()
-        finally:
-            symm_mem.empty = orig_symm_empty
-            symm_mem.rendezvous = orig_symm_rendezvous
-    else:
-        with torch.cuda.stream(capture_stream):
-            with torch.cuda.graph(graph):
-                fn()
-
-    # Keep references alive for as long as the graph exists.
-    graph._symm_mem_buffers = prepared_empty_buffers
-    graph._symm_mem_handles = prepared_handles
-    graph._capture_stream = capture_stream
-    return warmup_result, graph
-
-
 def _run_expert_sharding(rank, world_size, *, n_tokens, d_model, n_expts_tot, n_expts_act, affinity_mode):
     torch.manual_seed(0)
 
@@ -303,17 +209,33 @@ def run_mixture():
             y_indx=y_indx_global,
         )
 
-    # test cuda graph capture + replay with symmetric memory
-    y_dp_local_tri, graph = _capture_with_prepared_symm_mem(run_mixture)
+    symm_mem_pool.initialize_matmul_ogs(
+        n_tokens_global=n_tokens_global,
+        d_input=d_model,
+        d_model=d_model,
+        n_expts_act=n_expts_act,
+        n_expts_tot=n_expts_tot,
+        dtype=torch.bfloat16,
+        n_ranks=world_size,
+        group=dist.group.WORLD,
+        device=dev,
+    )
+    y_dp_local_tri = run_mixture()
     y_global_tri = torch.empty_like(y_global_ref)
 
     # Validate warmup run.
     dist.all_gather_into_tensor(y_global_tri, y_dp_local_tri)
     triton.testing.assert_close(y_global_ref, y_global_tri)
 
-    # Validate first replay with unchanged inputs.
-    graph.replay()
-    dist.all_gather_into_tensor(y_global_tri, y_dp_local_tri)
+    # Validate cuda graph capture + replay.
+    g = torch.cuda.CUDAGraph()
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        with torch.cuda.graph(g):
+            y_dp_local_tri_graph = run_mixture()
+
+    g.replay()
+    dist.all_gather_into_tensor(y_global_tri, y_dp_local_tri_graph)
     triton.testing.assert_close(y_global_ref, y_global_tri)
 
 
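Both sides of this diff implement the same underlying pattern: symmetric-memory allocation and rendezvous are host-side operations (rendezvous sets up cross-rank access to the buffer), so they cannot run inside CUDA graph capture. The removed helper handled this by monkeypatching `symm_mem.empty`/`symm_mem.rendezvous` to record allocations during a warmup run and hand back prepared buffers during capture; the new `symm_mem_pool.initialize_matmul_ogs` call instead preallocates the buffers `matmul_ogs` will need up front, so a plain `torch.cuda.graph` capture suffices. A minimal sketch of the allocate-outside/capture-inside idea, assuming an already-initialized process group (`work_fn` and the buffer size are hypothetical placeholders, not part of this commit):

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


def capture_reusing_symm_mem(work_fn, n_bytes=1 << 20):
    # Allocation + rendezvous run eagerly: rendezvous performs host-side,
    # cross-rank setup that cannot be recorded into a CUDA graph.
    buf = symm_mem.empty(n_bytes, dtype=torch.uint8, device="cuda")
    handle = symm_mem.rendezvous(buf, dist.group.WORLD.group_name)

    graph = torch.cuda.CUDAGraph()
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        with torch.cuda.graph(graph):
            out = work_fn(buf)  # only kernel launches are captured

    # Keep the buffer and handle alive as long as the graph may be replayed.
    graph._symm_mem_keepalive = (buf, handle)
    return graph, out
```

Replay is then just `graph.replay()`, with results read back from the same output tensors each time, which is exactly the `g.replay()` / `all_gather_into_tensor` / `assert_close` sequence the test performs above.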