Commit 0a240b3

Incorporate marlin into tgis_native
Signed-off-by: Chih-Chieh-Yang <[email protected]>
1 parent 91a9072 commit 0a240b3

File tree: 5 files changed, +263 -62 lines

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

Lines changed: 2 additions & 2 deletions
@@ -68,8 +68,8 @@ def _load_multi_mqa_gptq(
         g_idx = g_idx.to(device=weights.device)
         bits, groupsize = weights._get_gptq_params()
 
-        from text_generation_server.utils.layers import HAS_EXLLAMA
-        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, HAS_EXLLAMA)
+        from text_generation_server.utils.layers import HAS_GPTQ_CUDA
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, HAS_GPTQ_CUDA)
 
         if bias:
             slice_ = weights._get_slice(f"{prefix}.c_attn.bias")

server/text_generation_server/server.py

Lines changed: 26 additions & 25 deletions
@@ -277,31 +277,32 @@ async def serve_inner(
     print(model.config.__str__())
 
     if quantize == "gptq" and deployment_framework == "tgis_native":
-        from text_generation_server.utils.layers import HAS_EXLLAMA, EXLLAMA_VERSION
-        if HAS_EXLLAMA:
-            try:
-                # When using GPTQ, Exllama kernels need some global kernels
-                # For which we have the final shapes only after the model has loaded
-                # This will allocate those buffers.
-
-                if EXLLAMA_VERSION == "1":
-                    from text_generation_server.utils.gptq.exllama import (
-                        create_exllama_buffers, set_device,
-                    )
-                    set_device(device)
-                    create_exllama_buffers(max_sequence_length)
-                else:
-                    assert EXLLAMA_VERSION == "2"
-                    from text_generation_server.utils.gptq.exllamav2 import (
-                        set_device, Ex4bitLinearV2,
-                    )
-                    set_device(device)
-                    for _, submodule in model.model.named_modules():
-                        if isinstance(submodule, Ex4bitLinearV2):
-                            submodule.post_init()  # make q matrix and set scratch space
-
-            except ImportError:
-                print("WARN: Error setting up GPTQ exllama buffers")
+        from text_generation_server.utils.layers import HAS_GPTQ_CUDA, EXLLAMA_VERSION
+        if HAS_GPTQ_CUDA:
+            if EXLLAMA_VERSION is not None:
+                try:
+                    # When using GPTQ, Exllama kernels need some global kernels
+                    # For which we have the final shapes only after the model has loaded
+                    # This will allocate those buffers.
+
+                    if EXLLAMA_VERSION == "1":
+                        from text_generation_server.utils.gptq.exllama import (
+                            create_exllama_buffers, set_device,
+                        )
+                        set_device(device)
+                        create_exllama_buffers(max_sequence_length)
+                    else:
+                        assert EXLLAMA_VERSION == "2"
+                        from text_generation_server.utils.gptq.exllamav2 import (
+                            set_device, Ex4bitLinearV2,
+                        )
+                        set_device(device)
+                        for _, submodule in model.model.named_modules():
+                            if isinstance(submodule, Ex4bitLinearV2):
+                                submodule.post_init()  # make q matrix and set scratch space
+
+                except ImportError:
+                    print("WARN: Error setting up GPTQ exllama buffers")
 
     if local_rank == 0 and device.type == "cuda":
         # Log GPU memory stats at startup
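
Note: with the Marlin backend, no post-load buffer allocation is needed (MarlinQuantLinear.post_init() in the new marlin.py below is a no-op), which is why the buffer setup above only runs when EXLLAMA_VERSION is not None. A hypothetical, backend-agnostic variant of this hook (not part of this commit) could walk the loaded model and call post_init() on whatever GPTQ CUDA linear class is active; the exllama v1 buffer allocation would still need its own special case:

def finalize_gptq_cuda_layers(model):
    # Hypothetical sketch, not in the commit: relies on every GPTQ CUDA linear class
    # exposing post_init() (builds the q-matrix/scratch space for exllamav2, no-op for Marlin).
    from text_generation_server.utils.layers import HAS_GPTQ_CUDA, GPTQ_CUDA_LINEAR
    if not HAS_GPTQ_CUDA or GPTQ_CUDA_LINEAR is None:
        return
    for _, submodule in model.named_modules():
        if isinstance(submodule, GPTQ_CUDA_LINEAR):
            submodule.post_init()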

server/text_generation_server/utils/gptq/marlin.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+# Adapted from https://github.com/AutoGPTQ/AutoGPTQ/blob/main/auto_gptq/nn_modules/qlinear/qlinear_marlin.py
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+try:
+    import autogptq_marlin_cuda
+except ImportError as e:
+    marlin_import_exception = e
+
+    def error_raiser_marlin(*args, **kwargs):
+        raise ValueError(
+            f"Trying to use the marlin backend, but could not import the C++/CUDA dependencies with the following error: {marlin_import_exception}"
+        )
+
+    autogptq_marlin_cuda = error_raiser_marlin
+
+
+def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):
+    """Marlin FP16xINT4 multiply; can be used within `torch.compile`.
+    @A: `torch.half` input matrix of shape `(m, k)` in standard row-major layout
+    @B: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
+    @C: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
+    @s: `torch.half` scales of shape `(m / group_size, n)`
+    @workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero
+    @thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1)
+    @thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1)
+    @sms: number of SMs to use for the kernel (can usually be left as auto -1)
+    @max_par: maximum number of batch 64 problems to solve in parallel for large input sizes
+    """
+    autogptq_marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par)
+
+
+# Precompute permutations for Marlin weight and scale shuffling
+
+
+def _get_perms():
+    perm = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm.extend([p + 256 * j for p in perm1])
+
+    perm = np.array(perm)
+    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    perm = perm.reshape((-1, 8))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return perm, scale_perm, scale_perm_single
+
+# _perm, _scale_perm, _scale_perm_single = _get_perms()
+
+# def unpack_qzeros(qzeros):
+#     unpacked_zeros = torch.zeros(
+#         (qzeros.shape[0], qzeros.shape[1] * 8),
+#         dtype=torch.int8,
+#         device=qzeros.device,
+#         requires_grad=False,
+#     )
+
+#     for col in range(unpacked_zeros.shape[1]):
+#         i = col % 8
+#         unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
+
+#     return unpacked_zeros + 1
+
+def pack(x, nbits=4):
+    pack_size = 32 // nbits
+    out = torch.zeros((x.shape[0]//pack_size, x.shape[1]), dtype=x.dtype, device=x.device)
+    bitmask = 2**nbits - 1
+    for i in range(pack_size):
+        out |= (x[i::pack_size] & bitmask) << (nbits*i)
+    return out
+
+def unpack(x, nbits=4, axis=0):
+    assert nbits == 4
+    bitmask = 2**nbits - 1
+    pack_size = 32 // nbits
+    dim0_size = x.shape[0] * pack_size if axis == 0 else x.shape[0]
+    dim1_size = x.shape[1] * pack_size if axis == 1 else x.shape[1]
+    output = torch.empty((dim0_size, dim1_size), dtype=x.dtype, layout=x.layout, device=x.device)
+
+    if axis == 0:
+        for i in range(pack_size):
+            output[i::pack_size, :] = (x >> (i*nbits)) & bitmask
+    elif axis == 1:
+        for i in range(pack_size):
+            output[:, i::pack_size] = (x >> (i*nbits)) & bitmask
+    else:
+        assert False, "invalid unpack axis"
+    return output
+
+
+class MarlinQuantLinear(nn.Module):
+    QUANT_TYPE = "marlin"
+
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, group_size):
+        super().__init__()
+
+        pack_size = 32 // bits
+        infeatures = qweight.shape[0] * pack_size
+        outfeatures = qweight.shape[1]
+
+        if not torch.cuda.get_device_capability()[0] >= 8:
+            raise ValueError(f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}.')
+        if infeatures % 128 != 0 or outfeatures % 256 != 0:
+            raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
+        if bits not in [4]:
+            raise NotImplementedError("Only 4 bits are supported.")
+        if group_size not in [-1, 128] and group_size != infeatures:
+            raise ValueError("Only group_size -1 and 128 are supported.")
+        if infeatures % group_size != 0:
+            raise ValueError("`infeatures` must be divisible by `group_size`.")
+
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
+        self.group_size = group_size if group_size != -1 else infeatures
+
+        self.desc_act = not ( g_idx is None
+            or torch.equal(g_idx, torch.arange(infeatures, device=qweight.device) // group_size) )
+
+        if self.desc_act:
+            # shuffle weight rows
+            self.perm = torch.argsort(g_idx)
+            # unpack --> shuffle --> pack
+            qweight = pack(unpack(qweight)[self.perm])
+
+        # Repack into marlin format
+        self.B = autogptq_marlin_cuda.gptq_repack(qweight)
+
+        # # Check symmetric quantization, very slow, skipping for now
+        # dequantized_qzeros = unpack_qzeros(qzeros)
+        # if not torch.all(dequantized_qzeros == 8):
+        #     raise ValueError(
+        #         "Marlin kernel is compatible only with checkpoints using symetric quantization. "
+        #         "Found non-symmetric quantization for the weight {name}."
+        #     )
+
+        # Process scales
+        _, _scale_perm, _scale_perm_single = _get_perms()
+        s = scales.data.clone()
+        if group_size != infeatures:
+            s = s.reshape((1, -1))
+            s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+        else:
+            s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
+        s = s.reshape((-1, outfeatures)).contiguous()
+        self.s = s
+
+        # TODO: Can the workspace be shared among all marlin invocations?
+        self.workspace = torch.zeros(self.outfeatures // 128 * 16, dtype=torch.int, device=qweight.device)
+        self.bias = bias if bias is not None else None
+
+    def post_init(self):
+        pass
+
+    def forward(self, A):
+        A = A.half()
+        #Support activation reordering
+        if self.desc_act:
+            A = A[:, self.perm]
+        C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
+        mul(
+            A.view((-1, A.shape[-1])),
+            self.B,
+            C.view((-1, C.shape[-1])),
+            self.s,
+            self.workspace,
+        )
+        C = C + self.bias if self.bias is not None else C
+        return C
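
For reference, a quick sanity check (not part of this commit) of the int4 packing helpers above; it assumes the server package is importable and needs no GPU, since a failed autogptq_marlin_cuda import only swaps in an error-raising stub:

import torch
from text_generation_server.utils.gptq.marlin import pack, unpack

w = torch.randint(0, 16, (128, 256), dtype=torch.int32)  # fake 4-bit values, one per entry
qw = pack(w)                       # 8 values per int32 along dim 0 -> shape (16, 256)
assert torch.equal(unpack(qw), w)  # axis=0 unpack restores the original layout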

server/text_generation_server/utils/layers.py

Lines changed: 33 additions & 19 deletions
@@ -1,4 +1,6 @@
 import os
+from enum import Enum
+
 import torch
 import torch.distributed
 
@@ -11,8 +13,10 @@
 from accelerate import init_empty_weights
 
 HAS_BITS_AND_BYTES = False
-HAS_EXLLAMA = False
 EXLLAMA_VERSION = None
+HAS_GPTQ_CUDA = False
+GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "exllama").lower()
+GPTQ_CUDA_LINEAR = None
 
 if torch.cuda.is_available():
     try:
@@ -24,25 +28,35 @@
 
     from text_generation_server.utils.gptq.quant_linear import QuantLinear
 
-    if os.getenv("DISABLE_EXLLAMA", "False").lower() != "true":
-        try:
-            EXLLAMA_VERSION = os.getenv("EXLLAMA_VERSION", "2")  # Use v2 as default
-            if EXLLAMA_VERSION == "1":
-                from text_generation_server.utils.gptq.exllama import Ex4bitLinear as ExllamaQuantLinear
-            elif EXLLAMA_VERSION == "2":
-                from text_generation_server.utils.gptq.exllamav2 import Ex4bitLinearV2 as ExllamaQuantLinear
-            else:
-                raise ValueError(f"Unsupported value for EXLLAMA_VERSION: {EXLLAMA_VERSION}")
-            HAS_EXLLAMA = True
-        except ImportError as e:
-            print_rank_n(f"Error importing ExllamaV{EXLLAMA_VERSION} kernels: {e}")
-            EXLLAMA_VERSION = None
+    if os.getenv("DISABLE_EXLLAMA", "False").lower() != "true":  # Turn off all GPTQ CUDA kernels if set to true
+        if GPTQ_CUDA_TYPE == "exllama":
+            try:
+                EXLLAMA_VERSION = os.getenv("EXLLAMA_VERSION", "2")  # Use v2 as default
+                if EXLLAMA_VERSION == "1":  # TODO: consider removing v1 kernel
+                    from text_generation_server.utils.gptq.exllama import Ex4bitLinear as ExllamaQuantLinear
+                elif EXLLAMA_VERSION == "2":
+                    from text_generation_server.utils.gptq.exllamav2 import Ex4bitLinearV2 as ExllamaQuantLinear
+                else:
+                    raise ValueError(f"Unsupported value for EXLLAMA_VERSION: {EXLLAMA_VERSION}")
+                HAS_GPTQ_CUDA = True
+                GPTQ_CUDA_LINEAR = ExllamaQuantLinear
+            except ImportError as e:
+                print_rank_n(f"Error importing ExllamaV{EXLLAMA_VERSION} kernels: {e}")
+                EXLLAMA_VERSION = None
+        elif GPTQ_CUDA_TYPE == "marlin":
+            try:
+                from text_generation_server.utils.gptq.marlin import MarlinQuantLinear
+                GPTQ_CUDA_LINEAR = MarlinQuantLinear
+                HAS_GPTQ_CUDA = True
+            except ImportError as e:
+                print_rank_n(f"Error importing Marlin kernels: {e}")
+        else:
+            print_rank_n(f"Invalid GPTQ_CUDA_TYPE {GPTQ_CUDA_TYPE}")
 
 print_rank_n(
-    f"HAS_BITS_AND_BYTES={HAS_BITS_AND_BYTES}, HAS_EXLLAMA={HAS_EXLLAMA}, EXLLAMA_VERSION={EXLLAMA_VERSION}"
+    f"HAS_BITS_AND_BYTES={HAS_BITS_AND_BYTES}, HAS_GPTQ_CUDA={HAS_GPTQ_CUDA}, EXLLAMA_VERSION={EXLLAMA_VERSION}, GPTQ_CUDA_TYPE={GPTQ_CUDA_TYPE}"
 )
 
-
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -169,13 +183,13 @@ def get_linear(weight, bias, quantize):
         linear.bias = nn.Parameter(bias)
     elif quantize == "gptq":
         try:
-            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
+            qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda = weight
         except Exception:
            raise NotImplementedError(
                 f"The passed weight is not `gptq` compatible, loader needs to be updated."
             )
-
-        linear = (ExllamaQuantLinear if use_exllama else QuantLinear)(
+
+        linear = (QuantLinear if not use_gptq_cuda else GPTQ_CUDA_LINEAR)(
             qweight,
             qzeros,
             scales,
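
A minimal usage sketch (not part of this commit): the kernel backend is chosen once, at import time of layers.py, from the GPTQ_CUDA_TYPE environment variable (default "exllama"), so it must be set before the module is first imported. Assuming a CUDA-capable environment with the Marlin extension installed:

import os
os.environ["GPTQ_CUDA_TYPE"] = "marlin"   # or "exllama" (the default)
os.environ["DISABLE_EXLLAMA"] = "false"   # "true" turns off all GPTQ CUDA kernels

from text_generation_server.utils.layers import HAS_GPTQ_CUDA, GPTQ_CUDA_LINEAR

if HAS_GPTQ_CUDA:
    print(f"GPTQ CUDA linear: {GPTQ_CUDA_LINEAR.__name__}")  # MarlinQuantLinear
else:
    print("Falling back to the pure-PyTorch QuantLinear")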
