Skip to content

Commit 955ba64

Browse files
Optimization for quantized gemm skinny sizes (ROCm#411)
* Optimization for quantized gemm skinny sizes
* lint fix
* Add support for bf16/fp16
* code cleanup
* code cleanup
* lint fix2
* cleanup
* Moved the logic into tuned gemm to preserve API compatibility

---------

Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
1 parent 17b26bd commit 955ba64

File tree

7 files changed

+559
-52
lines changed

7 files changed

+559
-52
lines changed

csrc/rocm/custom.cu

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,24 @@ void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
4848
at::cuda::getCurrentCUDAStream(), CuCount);
4949
}
5050

51+
void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a,
52+
void* scale_b, const int M, const int K, const int Kp,
53+
const int N, const int Otp_in, cudaStream_t stream,
54+
const int CuCount);
55+
56+
// Tensor-facing wrapper around wvSpltKQ_.
//
// Reads size(0), size(1) and stride(0) from in_a, narrows N_in and Otp_in to
// int, and forwards the raw data pointers of all five tensors together with
// the current CUDA stream and CuCount to wvSpltKQ_.
void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
              at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in,
              const int64_t Otp_in, const int64_t CuCount) {
  // Dimensions and leading stride of in_a; passed on as M, K and Kp.
  const auto m_dim = in_a.size(0);
  const auto k_dim = in_a.size(1);
  const auto k_stride = in_a.stride(0);

  // wvSpltKQ_ takes plain ints, so the 64-bit arguments are narrowed here.
  const int n_dim = static_cast<int>(N_in);
  const int otp = static_cast<int>(Otp_in);

  wvSpltKQ_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(),
            scale_a.data_ptr(), scale_b.data_ptr(), m_dim, k_dim, k_stride,
            n_dim, otp, at::cuda::getCurrentCUDAStream(), CuCount);
}
68+
5169
void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
5270
cudaStream_t stream, const int solidx);
5371

0 commit comments

Comments
 (0)