Skip to content

Commit 80354e0

Browse files
authored
Use hand written schemas for pytorch op registration to tag mutable arguments (#17)
The mutable arguments for several ops were not being tagged. I've added schemas to the registration code so these arguments are marked properly. I've also added meta functions so the inductor can run. cc @abcdabcd987 , @nandor , @varun-sundar-rabindranath --------- Signed-off-by: Bill Nell <[email protected]>
1 parent 9d87d51 commit 80354e0

File tree

3 files changed

+104
-9
lines changed

3 files changed

+104
-9
lines changed

csrc/bindings/all_to_all_ops.cpp

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,19 @@ void dispatch(
195195
);
196196
}
197197

198+
void fake_dispatch(
199+
fptr_t ptr,
200+
at::Tensor &outExpertNumTokens,
201+
at::Tensor &outExpertX,
202+
const std::optional<at::Tensor> &outExpertXScale,
203+
const at::Tensor &dpX,
204+
const std::optional<at::Tensor> &dpXScale,
205+
const at::Tensor &indices,
206+
const std::optional<at::Tensor> &boundM,
207+
bool doSend,
208+
bool doRecv
209+
) {}
210+
198211
template <typename Kernel, typename T, typename U>
199212
void combineImpl(
200213
Kernel *all_to_all,
@@ -297,6 +310,17 @@ void combine(
297310
}
298311
}
299312

313+
void fake_combine(
314+
fptr_t ptr,
315+
at::Tensor &outTokens,
316+
const at::Tensor &indices,
317+
const at::Tensor &weights,
318+
const at::Tensor &expertY,
319+
const std::optional<at::Tensor> &boundM,
320+
bool doSend,
321+
bool doRecv
322+
) {}
323+
300324
#undef _CHECK_TENSOR
301325

302326
} // namespace
@@ -306,11 +330,64 @@ void register_all_to_all_ops(torch::Library &m) {
306330
m.def("all_to_all_destroy", &destroy);
307331

308332
m.def("all_to_all_internode_create", &create_internode);
309-
m.def("all_to_all_internode_dispatch", &dispatch<AllToAllInterNode>);
310-
m.def("all_to_all_internode_combine", &combine<AllToAllInterNode>);
333+
334+
m.def("all_to_all_internode_dispatch("
335+
" int fptr,"
336+
" Tensor! out_expert_num_tokens,"
337+
" Tensor! out_expert_x,"
338+
" Tensor!? out_expert_x_scale,"
339+
" Tensor dp_x,"
340+
" Tensor? dp_x_scale,"
341+
" Tensor indices,"
342+
" Tensor? bound_m,"
343+
" bool do_send,"
344+
" bool do_recv"
345+
") -> ()");
346+
m.impl("all_to_all_internode_dispatch", c10::kCUDA, &dispatch<AllToAllInterNode>);
347+
m.impl("all_to_all_internode_dispatch", c10::kMeta, &fake_dispatch);
348+
349+
m.def("all_to_all_internode_combine("
350+
" int fptr,"
351+
" Tensor! out_tokens,"
352+
" Tensor indices,"
353+
" Tensor weights,"
354+
" Tensor expert_y,"
355+
" Tensor? bound_m,"
356+
" bool do_send,"
357+
" bool do_recv"
358+
") -> ()");
359+
m.impl("all_to_all_internode_combine", c10::kCUDA, &combine<AllToAllInterNode>);
360+
m.impl("all_to_all_internode_combine", c10::kMeta, &fake_combine);
311361

312362
m.def("all_to_all_intranode_create", &create_intranode);
313-
m.def("all_to_all_intranode_dispatch", &dispatch<AllToAllIntraNode>);
314-
m.def("all_to_all_intranode_combine", &combine<AllToAllIntraNode>);
363+
364+
m.def("all_to_all_intranode_dispatch("
365+
" int fptr,"
366+
" Tensor! out_expert_num_tokens,"
367+
" Tensor! out_expert_x,"
368+
" Tensor!? out_expert_x_scale,"
369+
" Tensor dp_x,"
370+
" Tensor? dp_x_scale,"
371+
" Tensor indices,"
372+
" Tensor? bound_m,"
373+
" bool do_send,"
374+
" bool do_recv"
375+
") -> ()");
376+
m.impl("all_to_all_intranode_dispatch", c10::kCUDA, &dispatch<AllToAllIntraNode>);
377+
m.impl("all_to_all_intranode_dispatch", c10::kMeta, &fake_dispatch);
378+
379+
m.def("all_to_all_intranode_combine("
380+
" int fptr,"
381+
" Tensor! out_tokens,"
382+
" Tensor indices,"
383+
" Tensor weights,"
384+
" Tensor expert_y,"
385+
" Tensor? bound_m,"
386+
" bool do_send,"
387+
" bool do_recv"
388+
") -> ()");
389+
m.impl("all_to_all_intranode_combine", c10::kCUDA, &combine<AllToAllIntraNode>);
390+
m.impl("all_to_all_intranode_combine", c10::kMeta, &fake_combine);
315391
}
392+
316393
} // namespace pplx

csrc/bindings/nvshmem.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ void alltoall(at::Tensor dest, at::Tensor source) {
7979
));
8080
}
8181

82+
void fake_alltoall(at::Tensor dest, at::Tensor source) {}
83+
8284
} // namespace
8385

8486
void pplx::register_nvshmem_ops(torch::Library &m) {
@@ -91,5 +93,7 @@ void pplx::register_nvshmem_ops(torch::Library &m) {
9193
m.def("nvshmem_malloc", &malloc_tensor);
9294
m.def("nvshmem_barrier_all", &barrier_all);
9395
m.def("nvshmem_barrier_all_on_current_stream", &barrier_all_on_current_stream);
94-
m.def("nvshmem_alltoall", &alltoall);
96+
m.def("nvshmem_alltoall(Tensor! dest, Tensor src) -> ()");
97+
m.impl("nvshmem_alltoall", c10::kCUDA, &alltoall);
98+
m.impl("nvshmem_alltoall", c10::kMeta, &fake_alltoall);
9599
}

tests/test_all_to_all.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def _do_test_all_to_all(
5050
dp_size: int,
5151
moe: MoEConfig,
5252
internode: bool,
53+
use_compile: bool,
5354
) -> None:
5455
rank = pgi.rank
5556
local_rank = pgi.local_rank
@@ -173,7 +174,10 @@ def _do_test_all_to_all(
173174
)
174175
bound_m = torch.tensor([rank_data.num_tokens], dtype=torch.uint32, device=device)
175176
logger.debug("[rank=%d] Dispatch", rank)
176-
ata.dispatch(
177+
178+
dispatch = torch.compile(ata.dispatch) if use_compile else ata.dispatch
179+
180+
dispatch(
177181
out_expert_num_tokens=expert_num_tokens,
178182
out_expert_x=expert_x,
179183
out_expert_x_scale=expert_x_scale,
@@ -184,6 +188,7 @@ def _do_test_all_to_all(
184188
indices=rank_data.indices.to(device).to(torch.uint32),
185189
bound_m=bound_m,
186190
)
191+
187192
torch.cuda.synchronize()
188193
logger.debug("[rank=%d] Dispatch done", rank)
189194

@@ -253,7 +258,10 @@ def _do_test_all_to_all(
253258
)
254259

255260
logger.debug("[rank=%d] Combine", rank)
256-
ata.combine(
261+
262+
combine = torch.compile(ata.combine) if use_compile else ata.combine
263+
264+
combine(
257265
out_tokens=y,
258266
indices=rank_data.indices.to(device).to(torch.uint32),
259267
weights=rank_data.weights.to(device),
@@ -285,6 +293,7 @@ def _worker_test_all_to_all(
285293
out_dtype: str,
286294
moe_config: MoEConfig,
287295
internode: bool,
296+
use_compile: bool = False,
288297
) -> None:
289298
uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
290299
torch.distributed.broadcast(uid, src=0)
@@ -295,7 +304,8 @@ def _worker_test_all_to_all(
295304
in_dtype=getattr(torch, in_dtype),
296305
out_dtype=getattr(torch, out_dtype),
297306
)
298-
_do_test_all_to_all(pgi, dp_size, moe_config, internode)
307+
308+
_do_test_all_to_all(pgi, dp_size, moe_config, internode, use_compile)
299309

300310
nvshmem_finalize()
301311

@@ -304,7 +314,10 @@ def _worker_test_all_to_all(
304314
@pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
305315
@pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
306316
@pytest.mark.parametrize("internode", [True, False])
307-
def test_all_to_all_4_gpu(in_dtype: str, out_dtype: str, internode: bool) -> None:
317+
@pytest.mark.parametrize("use_compile", [False, True])
318+
def test_all_to_all_4_gpu(
319+
in_dtype: str, out_dtype: str, internode: bool, use_compile: bool
320+
) -> None:
308321
world_size = 4
309322
dp_size = 2
310323
parallel_launch(
@@ -315,6 +328,7 @@ def test_all_to_all_4_gpu(in_dtype: str, out_dtype: str, internode: bool) -> Non
315328
out_dtype,
316329
small_moe,
317330
internode,
331+
use_compile,
318332
)
319333

320334

0 commit comments

Comments
 (0)