
Commit 316ca8d

cyang49 and njhill authored
Re: Incorporate Marlin for GPTQ checkpoints into tgis_native (#66)
Resubmitting Marlin PR due to accidental removal

#### Motivation

This PR enables the use of the Marlin kernel for GPTQ checkpoints. Marlin has been shown to outperform ExllamaV2 on Nvidia GPUs, especially at larger batch sizes.

#### Modifications

The code changes largely mirror the exllamav2 integration, except that they use the Marlin kernel code and binding from the AutoGPTQ package instead of sourcing a separate marlin package. I adapted the QuantLinear implementation from AutoGPTQ, removing code that we don't need. Note that my changes also enable Marlin support for checkpoints that use activation reordering (`desc_act=True`). Marlin can be turned on by setting the environment variable `GPTQ_CUDA_TYPE=marlin`. Note that the Marlin kernel only works on Nvidia GPUs with compute capability >= 8.0.

#### Result

```
[Llama-70B-4bit-128g] Single A100x80GB, 1k context, output 512 tokens, batch size=16
Marlin    Prefill: 12.2s, Inference time: 38.57s
Exllamav2 Prefill: 9.68s, Inference time: 79.7s
```

- Investigations are needed, as Marlin prefill appears slower.

The code needs to be tested more thoroughly, for both performance and correctness, in the following scenarios:

- Should not break the fp16 logic
- Should work correctly for `desc_act=False` GPTQ checkpoints with optimal performance
- Should work correctly for `desc_act=True` GPTQ checkpoints, with slightly worse performance than the previous scenario
- Should not break TP use, although TP performance still needs further optimization
- Memory management needs extensive review

#### Related Issues

#51

---------

Signed-off-by: Chih-Chieh-Yang <[email protected]>
Signed-off-by: cyang49 <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
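As a quick orientation for reviewers, here is a minimal, hypothetical sketch of how the new selection switch is exercised; only the variable and module names come from this PR, the surrounding launch context is assumed:

```python
import os

# GPTQ_CUDA_TYPE is read when text_generation_server.utils.layers is imported
# (see the layers.py diff below), so it must be set before that import.
os.environ["GPTQ_CUDA_TYPE"] = "marlin"            # default is "exllama"
os.environ.setdefault("DISABLE_EXLLAMA", "false")  # "true" disables all GPTQ CUDA kernels

from text_generation_server.utils.layers import HAS_GPTQ_CUDA, GPTQ_CUDA_LINEAR

# On a CUDA device with compute capability >= 8.0 and the AutoGPTQ Marlin
# extension available, GPTQ_CUDA_LINEAR should resolve to MarlinQuantLinear.
print(HAS_GPTQ_CUDA, GPTQ_CUDA_LINEAR)
```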
1 parent 91a9072 commit 316ca8d

File tree

5 files changed (+243, -43 lines)

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

Lines changed: 2 additions & 2 deletions
Lines changed: 2 additions & 2 deletions

@@ -68,8 +68,8 @@ def _load_multi_mqa_gptq(
         g_idx = g_idx.to(device=weights.device)
         bits, groupsize = weights._get_gptq_params()
 
-        from text_generation_server.utils.layers import HAS_EXLLAMA
-        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, HAS_EXLLAMA)
+        from text_generation_server.utils.layers import HAS_GPTQ_CUDA
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, HAS_GPTQ_CUDA)
 
     if bias:
         slice_ = weights._get_slice(f"{prefix}.c_attn.bias")

server/text_generation_server/server.py

Lines changed: 5 additions & 6 deletions
Lines changed: 5 additions & 6 deletions

@@ -277,29 +277,28 @@ async def serve_inner(
        print(model.config.__str__())

        if quantize == "gptq" and deployment_framework == "tgis_native":
-            from text_generation_server.utils.layers import HAS_EXLLAMA, EXLLAMA_VERSION
-            if HAS_EXLLAMA:
+            from text_generation_server.utils.layers import HAS_GPTQ_CUDA, EXLLAMA_VERSION
+            if HAS_GPTQ_CUDA and EXLLAMA_VERSION is not None:
                try:
                    # When using GPTQ, Exllama kernels need some global kernels
                    # For which we have the final shapes only after the model has loaded
                    # This will allocate those buffers.
-
                    if EXLLAMA_VERSION == "1":
                        from text_generation_server.utils.gptq.exllama import (
                            create_exllama_buffers, set_device,
                        )
                        set_device(device)
                        create_exllama_buffers(max_sequence_length)
-                    else:
-                        assert EXLLAMA_VERSION == "2"
+                    elif EXLLAMA_VERSION == "2":
                        from text_generation_server.utils.gptq.exllamav2 import (
                            set_device, Ex4bitLinearV2,
                        )
                        set_device(device)
                        for _, submodule in model.model.named_modules():
                            if isinstance(submodule, Ex4bitLinearV2):
                                submodule.post_init() # make q matrix and set scratch space
-
+                    else:
+                        raise ValueError(f"Unsupported {EXLLAMA_VERSION=}")
                except ImportError:
                    print("WARN: Error setting up GPTQ exllama buffers")

server/text_generation_server/utils/gptq/marlin.py

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
+# Adapted from https://github.com/AutoGPTQ/AutoGPTQ/blob/main/auto_gptq/nn_modules/qlinear/qlinear_marlin.py
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+try:
+    import autogptq_marlin_cuda
+except ImportError as e:
+    marlin_import_exception = e
+
+    def error_raiser_marlin(*args, **kwargs):
+        raise ValueError(
+            f"Trying to use the marlin backend, but could not import the C++/CUDA dependencies with the following error: {marlin_import_exception}"
+        )
+
+    autogptq_marlin_cuda = error_raiser_marlin
+
+
+def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):
+    """Marlin FP16xINT4 multiply; can be used within `torch.compile`.
+    @A: `torch.half` input matrix of shape `(m, k)` in standard row-major layout
+    @B: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
+    @C: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
+    @s: `torch.half` scales of shape `(m / group_size, n)`
+    @workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero
+    @thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1)
+    @thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1)
+    @sms: number of SMs to use for the kernel (can usually be left as auto -1)
+    @max_par: maximum number of batch 64 problems to solve in parallel for large input sizes
+    """
+    autogptq_marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par)
+
+
+# Precompute permutations for Marlin weight and scale shuffling
+
+
+def _get_perms():
+    perm = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm.extend([p + 256 * j for p in perm1])
+
+    perm = np.array(perm)
+    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    perm = perm.reshape((-1, 8))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return perm, scale_perm, scale_perm_single
+
+# _perm, _scale_perm, _scale_perm_single = _get_perms()
+
+# def unpack_qzeros(qzeros):
+#     unpacked_zeros = torch.zeros(
+#         (qzeros.shape[0], qzeros.shape[1] * 8),
+#         dtype=torch.int8,
+#         device=qzeros.device,
+#         requires_grad=False,
+#     )
+
+#     for col in range(unpacked_zeros.shape[1]):
+#         i = col % 8
+#         unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
+
+#     return unpacked_zeros + 1
+
+def pack(x, nbits=4):
+    pack_size = 32 // nbits
+    out = torch.zeros((x.shape[0]//pack_size, x.shape[1]), dtype=x.dtype, device=x.device)
+    bitmask = 2**nbits - 1
+    for i in range(pack_size):
+        out |= (x[i::pack_size] & bitmask) << (nbits*i)
+    return out
+
+def unpack(x, nbits=4, axis=0):
+    assert nbits == 4
+    bitmask = 2**nbits - 1
+    pack_size = 32 // nbits
+    dim0_size = x.shape[0] * pack_size if axis == 0 else x.shape[0]
+    dim1_size = x.shape[1] * pack_size if axis == 1 else x.shape[1]
+    output = torch.empty((dim0_size, dim1_size), dtype=x.dtype, layout=x.layout, device=x.device)
+
+    if axis == 0:
+        for i in range(pack_size):
+            output[i::pack_size, :] = (x >> (i*nbits)) & bitmask
+    elif axis == 1:
+        for i in range(pack_size):
+            output[:, i::pack_size] = (x >> (i*nbits)) & bitmask
+    else:
+        assert False, "invalid unpack axis"
+    return output
+
+
+class MarlinQuantLinear(nn.Module):
+    QUANT_TYPE = "marlin"
+
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, group_size):
+        super().__init__()
+
+        pack_size = 32 // bits
+        infeatures = qweight.shape[0] * pack_size
+        outfeatures = qweight.shape[1]
+
+        device_capability = torch.cuda.get_device_capability()
+        if not device_capability[0] >= 8:
+            raise ValueError(f'Can not use Marlin int4*fp16 kernel with a device of compute capability {device_capability}.')
+        if infeatures % 128 != 0 or outfeatures % 256 != 0:
+            raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
+        if bits not in [4]:
+            raise NotImplementedError("Only 4 bits are supported.")
+        if group_size not in [-1, 128] and group_size != infeatures:
+            raise ValueError("Only group_size -1 and 128 are supported.")
+        if infeatures % group_size != 0:
+            raise ValueError("`infeatures` must be divisible by `group_size`.")
+
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
+        self.group_size = group_size if group_size != -1 else infeatures
+
+        self.desc_act = not ( g_idx is None
+                or torch.equal(g_idx, torch.arange(infeatures, device=qweight.device) // group_size) )
+
+        if self.desc_act:
+            # shuffle weight rows
+            self.perm = torch.argsort(g_idx)
+            # unpack --> shuffle --> pack
+            qweight = pack(unpack(qweight)[self.perm])
+
+        # Repack into marlin format
+        self.B = autogptq_marlin_cuda.gptq_repack(qweight)
+
+        # # Check symmetric quantization, very slow, skipping for now
+        # dequantized_qzeros = unpack_qzeros(qzeros)
+        # if not torch.all(dequantized_qzeros == 8):
+        #     raise ValueError(
+        #         "Marlin kernel is compatible only with checkpoints using symetric quantization. "
+        #         "Found non-symmetric quantization for the weight {name}."
+        #     )
+
+        # Process scales
+        _, _scale_perm, _scale_perm_single = _get_perms()
+        s = scales.data.clone()
+        if group_size != infeatures:
+            s = s.reshape((1, -1))
+            s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+        else:
+            s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
+        s = s.reshape((-1, outfeatures)).contiguous()
+        self.s = s
+
+        # TODO: Can the workspace be shared among all marlin invocations?
+        self.workspace = torch.zeros(self.outfeatures // 128 * 16, dtype=torch.int, device=qweight.device)
+        self.bias = bias if bias is not None else None
+
+    def post_init(self):
+        pass
+
+    def forward(self, A):
+        A = A.half()
+        #Support activation reordering
+        if self.desc_act:
+            A = A[:, self.perm]
+        C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
+        mul(
+            A.view((-1, A.shape[-1])),
+            self.B,
+            C.view((-1, C.shape[-1])),
+            self.s,
+            self.workspace,
+        )
+        C = C + self.bias if self.bias is not None else C
+        return C
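A minimal round-trip sketch of the `pack`/`unpack` helpers added above (the import path mirrors the one used in layers.py below; the CPU tensors and the restricted value range are illustrative assumptions):

```python
import torch

# Assumes the helpers are importable from the new module added in this PR.
from text_generation_server.utils.gptq.marlin import pack, unpack

pack_size = 32 // 4                  # eight 4-bit values per int32 word
x = torch.randint(0, 8, (2 * pack_size, 4), dtype=torch.int32)  # kept < 8 to stay clear of the int32 sign bit

packed = pack(x)                     # rows collapse 8-to-1 -> shape (2, 4)
restored = unpack(packed, axis=0)    # expands back to shape (16, 4)

assert torch.equal(restored, x)
```

This unpack -> permute -> pack round trip is what the constructor uses to pre-shuffle weight rows for `desc_act=True` checkpoints.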

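And a hedged construction/forward sketch for `MarlinQuantLinear`, assuming an Ampere-or-newer GPU with the `autogptq_marlin_cuda` extension installed; the tensor shapes and random contents below are purely illustrative (real values come from a GPTQ checkpoint via the weights loader):

```python
import torch
from text_generation_server.utils.gptq.marlin import MarlinQuantLinear

infeatures, outfeatures, group_size = 4096, 4096, 128
device = "cuda"

# Placeholder GPTQ tensors; shapes follow the constructor checks above.
qweight = torch.randint(0, 2**31 - 1, (infeatures // 8, outfeatures), dtype=torch.int32, device=device)
qzeros = torch.zeros((infeatures // group_size, outfeatures // 8), dtype=torch.int32, device=device)  # unused: symmetric check is commented out
scales = torch.rand(infeatures // group_size, outfeatures, dtype=torch.float16, device=device)

layer = MarlinQuantLinear(qweight, qzeros, scales, g_idx=None,  # g_idx=None -> desc_act treated as False
                          bias=None, bits=4, group_size=group_size)

A = torch.randn(16, infeatures, dtype=torch.float16, device=device)
C = layer(A)   # (16, outfeatures) in fp16
```

Note that `post_init()` is a no-op for Marlin, unlike `Ex4bitLinearV2`, which needs a post-load pass in server.py to build its q-matrix and scratch space.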
server/text_generation_server/utils/layers.py

Lines changed: 33 additions & 19 deletions
@@ -1,4 +1,6 @@
 import os
+from enum import Enum
+
 import torch
 import torch.distributed
 
@@ -11,8 +13,10 @@
 from accelerate import init_empty_weights
 
 HAS_BITS_AND_BYTES = False
-HAS_EXLLAMA = False
 EXLLAMA_VERSION = None
+HAS_GPTQ_CUDA = False
+GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "exllama").lower()
+GPTQ_CUDA_LINEAR = None
 
 if torch.cuda.is_available():
     try:
@@ -24,25 +28,35 @@
 
     from text_generation_server.utils.gptq.quant_linear import QuantLinear
 
-    if os.getenv("DISABLE_EXLLAMA", "False").lower() != "true":
-        try:
-            EXLLAMA_VERSION = os.getenv("EXLLAMA_VERSION", "2") # Use v2 as default
-            if EXLLAMA_VERSION == "1":
-                from text_generation_server.utils.gptq.exllama import Ex4bitLinear as ExllamaQuantLinear
-            elif EXLLAMA_VERSION == "2":
-                from text_generation_server.utils.gptq.exllamav2 import Ex4bitLinearV2 as ExllamaQuantLinear
-            else:
-                raise ValueError(f"Unsupported value for EXLLAMA_VERSION: {EXLLAMA_VERSION}")
-            HAS_EXLLAMA = True
-        except ImportError as e:
-            print_rank_n(f"Error importing ExllamaV{EXLLAMA_VERSION} kernels: {e}")
-            EXLLAMA_VERSION = None
+    if os.getenv("DISABLE_EXLLAMA", "False").lower() != "true": # Turn off all GPTQ CUDA kernels if set to true
+        if GPTQ_CUDA_TYPE == "exllama":
+            try:
+                EXLLAMA_VERSION = os.getenv("EXLLAMA_VERSION", "2") # Use v2 as default
+                if EXLLAMA_VERSION == "1": # TODO: consider removing v1 kernel
+                    from text_generation_server.utils.gptq.exllama import Ex4bitLinear as ExllamaQuantLinear
+                elif EXLLAMA_VERSION == "2":
+                    from text_generation_server.utils.gptq.exllamav2 import Ex4bitLinearV2 as ExllamaQuantLinear
+                else:
+                    raise ValueError(f"Unsupported value for EXLLAMA_VERSION: {EXLLAMA_VERSION}")
+                HAS_GPTQ_CUDA = True
+                GPTQ_CUDA_LINEAR = ExllamaQuantLinear
+            except ImportError as e:
+                print_rank_n(f"Error importing ExllamaV{EXLLAMA_VERSION} kernels: {e}")
+                EXLLAMA_VERSION = None
+        elif GPTQ_CUDA_TYPE == "marlin":
+            try:
+                from text_generation_server.utils.gptq.marlin import MarlinQuantLinear
+                GPTQ_CUDA_LINEAR = MarlinQuantLinear
+                HAS_GPTQ_CUDA = True
+            except ImportError as e:
+                print_rank_n(f"Error importing Marlin kernels: {e}")
+        else:
+            print_rank_n(f"Invalid GPTQ_CUDA_TYPE {GPTQ_CUDA_TYPE}")
 
 print_rank_n(
-    f"HAS_BITS_AND_BYTES={HAS_BITS_AND_BYTES}, HAS_EXLLAMA={HAS_EXLLAMA}, EXLLAMA_VERSION={EXLLAMA_VERSION}"
+    f"HAS_BITS_AND_BYTES={HAS_BITS_AND_BYTES}, HAS_GPTQ_CUDA={HAS_GPTQ_CUDA}, EXLLAMA_VERSION={EXLLAMA_VERSION}, GPTQ_CUDA_TYPE={GPTQ_CUDA_TYPE}"
 )
 
-
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -169,13 +183,13 @@ def get_linear(weight, bias, quantize):
         linear.bias = nn.Parameter(bias)
     elif quantize == "gptq":
         try:
-            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
+            qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda = weight
         except Exception:
             raise NotImplementedError(
                 f"The passed weight is not `gptq` compatible, loader needs to be updated."
             )
-
-        linear = (ExllamaQuantLinear if use_exllama else QuantLinear)(
+
+        linear = (QuantLinear if not use_gptq_cuda else GPTQ_CUDA_LINEAR)(
             qweight,
             qzeros,
             scales,

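To make the new dispatch easier to follow, here is a simplified, hypothetical helper equivalent to what `get_linear` now does for `quantize == "gptq"`; the constructor arguments after `scales` are truncated in the hunk above and deliberately not reproduced here:

```python
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.layers import GPTQ_CUDA_LINEAR

def pick_gptq_linear_cls(weight):
    # weight is the 7-tuple produced by weights.py (see the next diff); the
    # last element says whether a GPTQ CUDA kernel (exllama or marlin) was
    # successfully imported for this run.
    qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda = weight
    return GPTQ_CUDA_LINEAR if use_gptq_cuda else QuantLinear
```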
server/text_generation_server/utils/weights.py

Lines changed: 16 additions & 16 deletions
@@ -127,15 +127,15 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
             g_idx = w[0]
 
             bits, groupsize = self._get_gptq_params()
-            use_exllama = False
+            use_gptq_cuda = False
             if bits == 4:
-                from text_generation_server.utils.layers import HAS_EXLLAMA
+                from text_generation_server.utils.layers import HAS_GPTQ_CUDA
 
-                use_exllama = HAS_EXLLAMA
-                if use_exllama:
-                    logger.info(f"Using exllama kernels for col {prefixes}")
+                use_gptq_cuda = HAS_GPTQ_CUDA
+                if use_gptq_cuda:
+                    logger.info(f"Using GPTQ cuda kernels for col {prefixes}")
 
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
@@ -145,34 +145,34 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             bits, groupsize = self._get_gptq_params()
 
-            use_exllama = bits == 4
+            use_gptq_cuda = bits == 4
 
             if self.process_group.size() > 1:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:
                     if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
                         # Exllama implementation does not support row tensor parallelism with act-order, as
                         # it would require to reorder input activations that are split unto several GPUs
-                        use_exllama = False
+                        use_gptq_cuda = False
 
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
-            from text_generation_server.utils.layers import HAS_EXLLAMA
-            if use_exllama:
-                use_exllama = HAS_EXLLAMA
+            from text_generation_server.utils.layers import HAS_GPTQ_CUDA
+            if use_gptq_cuda:
+                use_gptq_cuda = HAS_GPTQ_CUDA
             if self.process_group.rank == 0:
-                if use_exllama:
-                    logger.info(f"Using exllama kernels for row {prefix}")
+                if use_gptq_cuda:
+                    logger.info(f"Using GPTQ cuda kernels for row {prefix}")
                 else:
                     logger.warning(
-                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
+                        "GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
                         " or not currently installed, try using BUILD_EXTENSIONS=True"
                     )
 
-            if use_exllama:
+            if use_gptq_cuda:
                 if groupsize >= 0:
                     # Exllama reorders the weights in advance and the activations on the fly, thus
                     # the scales and zero-points do not need to be reordered.
@@ -195,7 +195,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                 scales = self.get_tensor(f"{prefix}.scales")
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
 
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
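For context on the act-order guard in `get_multi_weights_row` above, a small illustrative sketch of what the g_idx test distinguishes (tensor contents are made up):

```python
import torch

groupsize, rows = 128, 512

# A "trivial" g_idx maps row i to group i // groupsize; with such an ordering,
# row tensor parallelism can keep the GPTQ CUDA kernel path.
trivial = torch.tensor([i // groupsize for i in range(rows)], dtype=torch.int32)

# A desc_act=True checkpoint stores a permuted g_idx instead (reversed here
# just for illustration), which trips the check and falls back to QuantLinear
# when the rows are sharded across GPUs.
act_order = trivial.flip(0)

def keeps_gptq_cuda(g_idx):
    return torch.equal(g_idx.cpu(), trivial) or bool((g_idx == 0).all())

assert keeps_gptq_cuda(trivial)
assert not keeps_gptq_cuda(act_order)
```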
