// kernel.cpp
// Forked from yandexdataschool/nlp_course.
#include <torch/extension.h>
// Forward declarations of the device-side matmul entry points. They are only
// declared in this file; their definitions must be supplied by another
// translation unit linked into the extension.
// NOTE(review): presumably these live in an accompanying CUDA source file and
// expect quantized (int4 / int8) operands — confirm against the definitions.
torch::Tensor int4MatmulCUDA(const torch::Tensor &A, const torch::Tensor &B);
torch::Tensor int8MatmulCUDA(const torch::Tensor &A, const torch::Tensor &B);
/// Validates both operands and forwards them to int4MatmulCUDA.
///
/// @param A First operand; must be contiguous and reside on a CUDA device.
/// @param B Second operand; must be contiguous and reside on the same CUDA
///          device as A.
/// @return Whatever int4MatmulCUDA(A, B) returns.
/// @throws c10::Error raised by the ATen check helpers when a precondition
///         is violated.
///
/// NOTE(review): shapes, dtypes and data layout are not validated here;
/// presumably int4MatmulCUDA enforces them — confirm against its definition.
/// NOTE(review): no CUDA device guard is installed here, so the callee is
/// responsible for launching on the operands' device — confirm.
torch::Tensor int4Matmul(const torch::Tensor &A, const torch::Tensor &B) {
  // TensorArg positions are 1-indexed; they only influence the wording of
  // error messages ("argument #1 'A'").
  const at::TensorArg argA{A, "A", 1};
  const at::TensorArg argB{B, "B", 2};

  torch::checkAllContiguous("int4Matmul", {argA, argB});
  torch::checkDeviceType("int4Matmul", {A, B}, at::DeviceType::CUDA);
  // Both operands being CUDA tensors is not sufficient: they must also live
  // on the same GPU before their storage is handed to a single kernel.
  torch::checkAllSameGPU("int4Matmul", {argA, argB});

  return int4MatmulCUDA(A, B);
}
/// Validates both operands and forwards them to int8MatmulCUDA.
///
/// @param A First operand; must be contiguous and reside on a CUDA device.
/// @param B Second operand; must be contiguous and reside on the same CUDA
///          device as A.
/// @return Whatever int8MatmulCUDA(A, B) returns.
/// @throws c10::Error raised by the ATen check helpers when a precondition
///         is violated.
///
/// NOTE(review): shapes, dtypes and data layout are not validated here;
/// presumably int8MatmulCUDA enforces them — confirm against its definition.
/// NOTE(review): no CUDA device guard is installed here, so the callee is
/// responsible for launching on the operands' device — confirm.
torch::Tensor int8Matmul(const torch::Tensor &A, const torch::Tensor &B) {
  // TensorArg positions are 1-indexed; they only influence the wording of
  // error messages ("argument #1 'A'").
  const at::TensorArg argA{A, "A", 1};
  const at::TensorArg argB{B, "B", 2};

  torch::checkAllContiguous("int8Matmul", {argA, argB});
  torch::checkDeviceType("int8Matmul", {A, B}, at::DeviceType::CUDA);
  // Both operands being CUDA tensors is not sufficient: they must also live
  // on the same GPU before their storage is handed to a single kernel.
  torch::checkAllSameGPU("int8Matmul", {argA, argB});

  return int8MatmulCUDA(A, B);
}
// Python bindings: registers the two validated host-side wrappers on the
// extension module under snake_case names. The module name comes from the
// TORCH_EXTENSION_NAME macro, which is defined outside this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Docstrings attached to the exported Python callables.
  constexpr const char *kInt4Doc = "int4 matmul (CUDA)";
  constexpr const char *kInt8Doc = "int8 matmul (CUDA)";

  m.def("int4_matmul", &int4Matmul, kInt4Doc);
  m.def("int8_matmul", &int8Matmul, kInt8Doc);
}