Skip to content

Commit 6c670dd

Browse files
cyang49 and njhill committed
Adding exllamav2 support for GPTQ models
This PR adds exllamav2 kernels. The added changes are adapted from two open source repositories:
- https://github.com/turboderp/exllamav2
- https://github.com/PanQiWei/AutoGPTQ

Co-authored-by: Nick Hill <[email protected]>
1 parent 27e0952 commit 6c670dd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+6574
-28
lines changed

Dockerfile

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,13 @@ WORKDIR /usr/src
222222
COPY server/exllama_kernels/ .
223223
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
224224

225+
## Build transformers exllamav2 kernels ########################################
226+
FROM python-builder as exllamav2-kernels-builder
227+
228+
WORKDIR /usr/src
229+
230+
COPY server/exllamav2_kernels/ .
231+
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
225232

226233
## Flash attention cached build image ##########################################
227234
FROM base as flash-att-cache
@@ -262,6 +269,9 @@ COPY --from=flash-att-v2-cache /usr/src/flash-attention-v2/build/lib.linux-x86_6
262269
# Copy build artifacts from exllama kernels builder
263270
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-* ${SITE_PACKAGES}
264271

272+
# Copy build artifacts from exllamav2 kernels builder
273+
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-* ${SITE_PACKAGES}
274+
265275
# Install server
266276
COPY proto proto
267277
COPY server server
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Build-time configuration for the exllamav2 CUDA kernels.
//
// NOTE(review): renamed the include guard — `_config_h` begins with an
// underscore at global scope, which is reserved to the implementation by the
// C/C++ standards ([lex.name]).
#ifndef EXLLAMAV2_CONFIG_H
#define EXLLAMAV2_CONFIG_H

// Row-count threshold used by the quantized GEMM dispatch — presumably inputs
// with more rows take a different (dequantize + cuBLAS) path. TODO confirm
// against the kernel dispatch code.
#define MAX_Q_GEMM_ROWS 50

// Per-bit-width quantization mode selectors (kernel variant toggles, adapted
// from upstream exllamav2 — confirm meaning against the packing kernels).
#define QMODE_2BIT 1
#define QMODE_3BIT 1
#define QMODE_4BIT 1
#define QMODE_5BIT 1
#define QMODE_6BIT 0
#define QMODE_8BIT 0

#endif // EXLLAMAV2_CONFIG_H
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include "quantize_func.h"
2+
#include "../cuda/quantize.cuh"
3+
4+
// Quantizes rows [a, b) of `weights` in place, one row at a time, and
// propagates the quantization error of each row into the rows that have not
// been quantized yet (GPTQ-style error compensation driven by the inverse
// Hessian). Finally, the accumulated error of the whole block [a, b) is
// applied to all rows from b onward.
//
// quant       : out — (de)quantized weight values, same layout as `weights`
// scale       : quantization scales passed straight to the CUDA kernel
// out_q       : out — integer quantized values (uint16 storage); pass a meta
//               tensor to skip writing them
// qzero, maxq : quantization zero point and maximum quantized value
// hessian_inv : inverse Hessian; only size(1) and row slices are used here
// weights     : in/out — updated in place with propagated error
// error       : out — per-row quantization error, filled by the CUDA kernel
// a, b        : row range to quantize, [a, b)
void quantize_range
(
    torch::Tensor quant,
    torch::Tensor scale,
    torch::Tensor out_q,
    float qzero,
    float maxq,
    torch::Tensor hessian_inv,
    torch::Tensor weights,
    torch::Tensor error,
    int a,
    int b
)
{
    int columns = weights.size(1);
    int hcolumns = hessian_inv.size(1);

    for (int c = a; c < b; c++)
    {
        // Quantize the single row c (rows = 1, cols = columns). A meta
        // `out_q` has no storage, so NULL tells the kernel not to write the
        // integer outputs.
        quantize_cuda
        (
            ((const float*) weights.data_ptr()) + c * columns,
            ((float*) quant.data_ptr()) + c * columns,
            (const float*) scale.data_ptr(),
            out_q.device().is_meta() ? NULL : ((uint16_t*) out_q.data_ptr()) + c * columns,
            1,
            columns,
            qzero,
            maxq
        );

        // Fill error row c from weights[c] and quant[c]; presumably
        // error[c] = (weights[c] - quant[c]) / hessian_inv[c, c] — the exact
        // formula lives in ../cuda/quantize.cuh, TODO confirm.
        adjust_error_row_cuda
        (
            (const float*) hessian_inv.data_ptr(),
            (float*) error.data_ptr(),
            (const float*) weights.data_ptr(),
            (const float*) quant.data_ptr(),
            c,
            columns,
            hcolumns
        );

        // Update the b - c not-yet-finalized rows starting at row c:
        // pointers are hessian_inv[c, c], error[c] and weights[c], so this
        // presumably computes weights[c:b] -= hessian_inv[c, c:b] (outer)
        // error[c] — verify against the vv_mul_sub kernel.
        vv_mul_sub_cuda
        (
            ((const float*) hessian_inv.data_ptr()) + c * hcolumns + c,
            ((const float*) error.data_ptr()) + c * columns,
            ((float*) weights.data_ptr()) + c * columns,
            b - c,
            columns
        );
    }

    // Propagate the block's accumulated error to every row past b:
    //   weights[b:] -= hessian_inv[a:b, b:].T @ error[a:b]
    // (addmm_ with beta = 1.0f keeps the existing values, alpha = -1.0f
    // subtracts the matrix product).
    torch::Tensor x = hessian_inv.slice(0, a, b).slice(1, b).transpose(0, 1);
    torch::Tensor y = error.slice(0, a, b);
    weights.slice(0, b).addmm_(x, y, 1.0f, -1.0f);
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Host-side entry point for GPTQ-style block quantization, built on the CUDA
// kernels in ../cuda/quantize.cuh.
//
// NOTE(review): renamed the include guard — `_quantize_func_h` begins with an
// underscore at global scope, which is reserved to the implementation by the
// C/C++ standards ([lex.name]).
#ifndef QUANTIZE_FUNC_H
#define QUANTIZE_FUNC_H

#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <ATen/cuda/CUDAContext.h>
#include <cstdint>
#include <cstdio>

// Quantizes rows [a, b) of `weights` in place, writing the (de)quantized
// values to `quant` and the integer codes to `out_q` (pass a meta tensor to
// skip them), and propagates the per-row quantization error — stored in
// `error` — to the remaining rows via `hessian_inv`. `qzero`/`maxq` are the
// quantization zero point and maximum quantized value.
void quantize_range
(
    torch::Tensor quant,
    torch::Tensor scale,
    torch::Tensor out_q,
    float qzero,
    float maxq,
    torch::Tensor hessian_inv,
    torch::Tensor weights,
    torch::Tensor error,
    int a,
    int b
);

#endif // QUANTIZE_FUNC_H

0 commit comments

Comments
 (0)