
Commit fc5ee7f

jwfromm authored and facebook-github-bot committed
Add CPU registrations to custom operators (#3262)
Summary: Pull Request resolved: #3262 X-link: facebookresearch/FBGEMM#363 While CPU arguments shouldn't be used for custom CUDA kernels, it turns out they sometimes are in production. The outputs will be garbage, but doing so seems to be part of the model construction process. This small diff fixes the issue by adding CPU registrations for custom operators. This should enable production use cases without breaking torch.export support. Reviewed By: jaconey, jianyuh, jiawenliu64 Differential Revision: D64703788 fbshipit-source-id: c0c8cfb7f0b67c13be10f419c8e3d83991429edb
1 parent d32fc6a commit fc5ee7f
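For readers outside FBGEMM, the diffs below rely on PyTorch's dispatch-key registration: an operator schema is declared once, and each backend registers its own kernel against it. The following is a minimal standalone sketch of that pattern, not code from this commit; the "myops::my_allreduce" namespace, op name, and kernel bodies are hypothetical.

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical CUDA-path kernel (the real work would live in a .cu file).
at::Tensor my_allreduce_cuda(at::Tensor input) {
  // ... launch the actual CUDA kernel here ...
  return input;
}

// Hypothetical CPU stand-in: it exists only so model construction and
// tracing succeed on CPU tensors; as the commit notes, the output is
// garbage rather than a real allreduce result.
at::Tensor my_allreduce_cpu(at::Tensor input) {
  return at::empty_like(input);
}

// Declare the operator schema once...
TORCH_LIBRARY(myops, m) {
  m.def("my_allreduce(Tensor input) -> Tensor");
}

// ...then register one implementation per dispatch key. The dispatcher
// selects a kernel from the device of the input tensors.
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("my_allreduce", my_allreduce_cuda);
}
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("my_allreduce", my_allreduce_cpu);
}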

2 files changed: +31 −0 lines changed

fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp

Lines changed: 11 additions & 0 deletions
@@ -278,4 +278,15 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
 }

+// Though it shouldn't be used, it is useful to define these functions for
+// CPU to accommodate model creation.
+TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
+  m.impl("nccl_allreduce", nccl_allreduce);
+  m.impl("nccl_allgather", nccl_allgather);
+  m.impl("nccl_alltoall", nccl_alltoall);
+  m.impl("nccl_reducescatter", nccl_reducescatter);
+  m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
+  m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
+}
+
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 20 additions & 0 deletions
@@ -214,4 +214,24 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
 #endif
 }

+// Though it should never be used, it still seems helpful to define these
+// functions for CPU to accommodate model creation.
+TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
+  m.impl("f8f8bf16_blockwise", f8f8bf16_blockwise);
+  m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise);
+  m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise);
+  m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
+  m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
+  m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
+#ifndef USE_ROCM
+  m.impl("i8i8bf16", i8i8bf16);
+  m.impl("f8f8bf16", f8f8bf16);
+  m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
+  m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched);
+  m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
+  m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
+  m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
+#endif
+}
+
 } // namespace fbgemm_gpu
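Continuing the hypothetical sketch from above (not FBGEMM's real op names or schemas), this is how a caller would exercise such CPU registrations through the dispatcher. With a CPU impl in place, a CPU tensor resolves to the stand-in kernel instead of raising a "Could not run ... with arguments from the 'CPU' backend" error during model construction.

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/torch.h>

int main() {
  // Resolve the hypothetical op and its signature via the dispatcher.
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("myops::my_allreduce", "")
                .typed<at::Tensor(at::Tensor)>();

  // CPU input -> CPU registration: succeeds, but the output is garbage.
  at::Tensor cpu_out = op.call(torch::zeros({8}));

  // CUDA input -> CUDA registration: the real kernel.
  // at::Tensor gpu_out = op.call(torch::zeros({8}, torch::kCUDA));
  return 0;
}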
