Re: Incorporate Marlin for GPTQ checkpoints into tgis_native #66

Merged · 5 commits · Mar 25, 2024
@@ -68,8 +68,8 @@ def _load_multi_mqa_gptq(
g_idx = g_idx.to(device=weights.device)
bits, groupsize = weights._get_gptq_params()

from text_generation_server.utils.layers import HAS_EXLLAMA
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, HAS_EXLLAMA)
from text_generation_server.utils.layers import HAS_GPTQ_CUDA
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, HAS_GPTQ_CUDA)

if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
11 changes: 5 additions & 6 deletions server/text_generation_server/server.py
@@ -277,29 +277,28 @@ async def serve_inner(
print(model.config.__str__())

if quantize == "gptq" and deployment_framework == "tgis_native":
from text_generation_server.utils.layers import HAS_EXLLAMA, EXLLAMA_VERSION
if HAS_EXLLAMA:
from text_generation_server.utils.layers import HAS_GPTQ_CUDA, EXLLAMA_VERSION
if HAS_GPTQ_CUDA and EXLLAMA_VERSION is not None:
try:
# When using GPTQ, the Exllama kernels need some global buffers
# for which the final shapes are only known after the model has loaded.
# This will allocate those buffers.

if EXLLAMA_VERSION == "1":
from text_generation_server.utils.gptq.exllama import (
create_exllama_buffers, set_device,
)
set_device(device)
create_exllama_buffers(max_sequence_length)
else:
assert EXLLAMA_VERSION == "2"
elif EXLLAMA_VERSION == "2":
from text_generation_server.utils.gptq.exllamav2 import (
set_device, Ex4bitLinearV2,
)
set_device(device)
for _, submodule in model.model.named_modules():
if isinstance(submodule, Ex4bitLinearV2):
submodule.post_init() # make q matrix and set scratch space

else:
raise ValueError(f"Unsupported {EXLLAMA_VERSION=}")
except ImportError:
print("WARN: Error setting up GPTQ exllama buffers")

187 changes: 187 additions & 0 deletions server/text_generation_server/utils/gptq/marlin.py
@@ -0,0 +1,187 @@
# Adapted from https://github.com/AutoGPTQ/AutoGPTQ/blob/main/auto_gptq/nn_modules/qlinear/qlinear_marlin.py

import numpy as np
import torch
import torch.nn as nn

try:
import autogptq_marlin_cuda
except ImportError as e:
marlin_import_exception = e

def error_raiser_marlin(*args, **kwargs):
raise ValueError(
f"Trying to use the marlin backend, but could not import the C++/CUDA dependencies with the following error: {marlin_import_exception}"
)

autogptq_marlin_cuda = error_raiser_marlin


def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):
"""Marlin FP16xINT4 multiply; can be used within `torch.compile`.
@A: `torch.half` input matrix of shape `(m, k)` in standard row-major layout
@B: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
@C: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
@s: `torch.half` scales of shape `(m / group_size, n)`
@workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero
@thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1)
@thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1)
@sms: number of SMs to use for the kernel (can usually be left as auto -1)
@max_par: maximum number of batch 64 problems to solve in parallel for large input sizes
"""
autogptq_marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par)


# Precompute permutations for Marlin weight and scale shuffling


def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])

perm = np.array(perm)
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single

# _perm, _scale_perm, _scale_perm_single = _get_perms()

# def unpack_qzeros(qzeros):
# unpacked_zeros = torch.zeros(
# (qzeros.shape[0], qzeros.shape[1] * 8),
# dtype=torch.int8,
# device=qzeros.device,
# requires_grad=False,
# )

# for col in range(unpacked_zeros.shape[1]):
# i = col % 8
# unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF

# return unpacked_zeros + 1

def pack(x, nbits=4):
pack_size = 32 // nbits
out = torch.zeros((x.shape[0]//pack_size, x.shape[1]), dtype=x.dtype, device=x.device)
bitmask = 2**nbits - 1
for i in range(pack_size):
out |= (x[i::pack_size] & bitmask) << (nbits*i)
return out

def unpack(x, nbits=4, axis=0):
assert nbits == 4
bitmask = 2**nbits - 1
pack_size = 32 // nbits
dim0_size = x.shape[0] * pack_size if axis == 0 else x.shape[0]
dim1_size = x.shape[1] * pack_size if axis == 1 else x.shape[1]
output = torch.empty((dim0_size, dim1_size), dtype=x.dtype, layout=x.layout, device=x.device)

if axis == 0:
for i in range(pack_size):
output[i::pack_size, :] = (x >> (i*nbits)) & bitmask
elif axis == 1:
for i in range(pack_size):
output[:, i::pack_size] = (x >> (i*nbits)) & bitmask
else:
assert False, "invalid unpack axis"
return output
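
As a quick sanity check (not part of the diff), pack and unpack are inverses for 4-bit values along dim 0, which is what the row shuffle in MarlinQuantLinear.__init__ below relies on; a minimal sketch, assuming the two helpers above are in scope:

x = torch.randint(0, 16, (16, 4), dtype=torch.int32)   # toy tensor of 4-bit values
packed = pack(x)                                        # shape (2, 4): eight 4-bit values per int32 along dim 0
assert torch.equal(unpack(packed, nbits=4, axis=0), x)  # lossless round trip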


class MarlinQuantLinear(nn.Module):
QUANT_TYPE = "marlin"

def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, group_size):
super().__init__()

pack_size = 32 // bits
infeatures = qweight.shape[0] * pack_size
outfeatures = qweight.shape[1]

device_capability = torch.cuda.get_device_capability()
if not device_capability[0] >= 8:
raise ValueError(f'Can not use Marlin int4*fp16 kernel with a device of compute capability {device_capability}.')
if infeatures % 128 != 0 or outfeatures % 256 != 0:
raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
if group_size not in [-1, 128] and group_size != infeatures:
raise ValueError("Only group_size -1 and 128 are supported.")
if infeatures % group_size != 0:
raise ValueError("`infeatures` must be divisible by `group_size`.")

self.infeatures = infeatures
self.outfeatures = outfeatures
self.group_size = group_size if group_size != -1 else infeatures

self.desc_act = not ( g_idx is None
or torch.equal(g_idx, torch.arange(infeatures, device=qweight.device) // group_size) )

if self.desc_act:
# shuffle weight rows
self.perm = torch.argsort(g_idx)
# unpack --> shuffle --> pack
qweight = pack(unpack(qweight)[self.perm])

# Repack into marlin format
self.B = autogptq_marlin_cuda.gptq_repack(qweight)

# # Check symmetric quantization, very slow, skipping for now
# dequantized_qzeros = unpack_qzeros(qzeros)
# if not torch.all(dequantized_qzeros == 8):
# raise ValueError(
# "Marlin kernel is compatible only with checkpoints using symetric quantization. "
# "Found non-symmetric quantization for the weight {name}."
# )

# Process scales
_, _scale_perm, _scale_perm_single = _get_perms()
s = scales.data.clone()
if group_size != infeatures:
s = s.reshape((1, -1))
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, outfeatures)).contiguous()
self.s = s

# TODO: Can the workspace be shared among all marlin invocations?
self.workspace = torch.zeros(self.outfeatures // 128 * 16, dtype=torch.int, device=qweight.device)
self.bias = bias if bias is not None else None

def post_init(self):
pass

def forward(self, A):
A = A.half()
# Support activation reordering
if self.desc_act:
A = A[:, self.perm]
C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
mul(
A.view((-1, A.shape[-1])),
self.B,
C.view((-1, C.shape[-1])),
self.s,
self.workspace,
)
C = C + self.bias if self.bias is not None else C
return C
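
To make the shape contract concrete, here is a hedged usage sketch of the new layer (illustrative only: it assumes a GPU with compute capability >= 8 and that the autogptq_marlin_cuda extension is installed; the tensors are random, so only shapes are meaningful):

import torch
from text_generation_server.utils.gptq.marlin import MarlinQuantLinear

# Invented sizes that satisfy the constructor checks: infeatures % 128 == 0, outfeatures % 256 == 0.
infeatures, outfeatures, group_size = 256, 256, 128
qweight = torch.randint(0, 2**31 - 1, (infeatures // 8, outfeatures), dtype=torch.int32, device="cuda")
qzeros = torch.zeros(infeatures // group_size, outfeatures // 8, dtype=torch.int32, device="cuda")  # unused here: the symmetry check is commented out
scales = torch.rand(infeatures // group_size, outfeatures, dtype=torch.half, device="cuda")
g_idx = torch.arange(infeatures, device="cuda") // group_size  # trivial g_idx, i.e. no act-order

layer = MarlinQuantLinear(qweight, qzeros, scales, g_idx, None, 4, group_size)
out = layer(torch.randn(2, infeatures, dtype=torch.half, device="cuda"))
assert out.shape == (2, outfeatures)

Since post_init is a no-op here, no extra buffer setup is needed after model load, which is why the buffer-allocation block in server.py above is guarded by EXLLAMA_VERSION is not None.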
52 changes: 33 additions & 19 deletions server/text_generation_server/utils/layers.py
@@ -1,4 +1,6 @@
import os
from enum import Enum

import torch
import torch.distributed

@@ -11,8 +13,10 @@
from accelerate import init_empty_weights

HAS_BITS_AND_BYTES = False
HAS_EXLLAMA = False
EXLLAMA_VERSION = None
HAS_GPTQ_CUDA = False
GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "exllama").lower()
GPTQ_CUDA_LINEAR = None

if torch.cuda.is_available():
try:
@@ -24,25 +28,35 @@

from text_generation_server.utils.gptq.quant_linear import QuantLinear

if os.getenv("DISABLE_EXLLAMA", "False").lower() != "true":
try:
EXLLAMA_VERSION = os.getenv("EXLLAMA_VERSION", "2") # Use v2 as default
if EXLLAMA_VERSION == "1":
from text_generation_server.utils.gptq.exllama import Ex4bitLinear as ExllamaQuantLinear
elif EXLLAMA_VERSION == "2":
from text_generation_server.utils.gptq.exllamav2 import Ex4bitLinearV2 as ExllamaQuantLinear
else:
raise ValueError(f"Unsupported value for EXLLAMA_VERSION: {EXLLAMA_VERSION}")
HAS_EXLLAMA = True
except ImportError as e:
print_rank_n(f"Error importing ExllamaV{EXLLAMA_VERSION} kernels: {e}")
EXLLAMA_VERSION = None
if os.getenv("DISABLE_EXLLAMA", "False").lower() != "true": # Turn off all GPTQ CUDA kernels if set to true
if GPTQ_CUDA_TYPE == "exllama":
try:
EXLLAMA_VERSION = os.getenv("EXLLAMA_VERSION", "2") # Use v2 as default
if EXLLAMA_VERSION == "1": # TODO: consider removing v1 kernel
from text_generation_server.utils.gptq.exllama import Ex4bitLinear as ExllamaQuantLinear
elif EXLLAMA_VERSION == "2":
from text_generation_server.utils.gptq.exllamav2 import Ex4bitLinearV2 as ExllamaQuantLinear
else:
raise ValueError(f"Unsupported value for EXLLAMA_VERSION: {EXLLAMA_VERSION}")
HAS_GPTQ_CUDA = True
GPTQ_CUDA_LINEAR = ExllamaQuantLinear
except ImportError as e:
print_rank_n(f"Error importing ExllamaV{EXLLAMA_VERSION} kernels: {e}")
EXLLAMA_VERSION = None
elif GPTQ_CUDA_TYPE == "marlin":
try:
from text_generation_server.utils.gptq.marlin import MarlinQuantLinear
GPTQ_CUDA_LINEAR = MarlinQuantLinear
HAS_GPTQ_CUDA = True
except ImportError as e:
print_rank_n(f"Error importing Marlin kernels: {e}")
else:
print_rank_n(f"Invalid GPTQ_CUDA_TYPE {GPTQ_CUDA_TYPE}")

print_rank_n(
f"HAS_BITS_AND_BYTES={HAS_BITS_AND_BYTES}, HAS_EXLLAMA={HAS_EXLLAMA}, EXLLAMA_VERSION={EXLLAMA_VERSION}"
f"HAS_BITS_AND_BYTES={HAS_BITS_AND_BYTES}, HAS_GPTQ_CUDA={HAS_GPTQ_CUDA}, EXLLAMA_VERSION={EXLLAMA_VERSION}, GPTQ_CUDA_TYPE={GPTQ_CUDA_TYPE}"
)


# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
@@ -169,13 +183,13 @@ def get_linear(weight, bias, quantize):
linear.bias = nn.Parameter(bias)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)

linear = (ExllamaQuantLinear if use_exllama else QuantLinear)(
linear = (QuantLinear if not use_gptq_cuda else GPTQ_CUDA_LINEAR)(
qweight,
qzeros,
scales,
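
The backend is thus selected entirely through environment variables read when layers.py is first imported; a minimal sketch of the configurations this diff supports (illustrative only, variable names as introduced above):

import os

# Must be set before text_generation_server.utils.layers is imported.
os.environ["GPTQ_CUDA_TYPE"] = "marlin"     # or "exllama" (the default)
os.environ["EXLLAMA_VERSION"] = "2"         # only consulted when GPTQ_CUDA_TYPE=exllama
# os.environ["DISABLE_EXLLAMA"] = "true"    # turns off all GPTQ CUDA kernels, falling back to QuantLinear

from text_generation_server.utils.layers import HAS_GPTQ_CUDA, GPTQ_CUDA_LINEAR
print(HAS_GPTQ_CUDA, GPTQ_CUDA_LINEAR)      # e.g. True and MarlinQuantLinear if the marlin extension imports cleanly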
32 changes: 16 additions & 16 deletions server/text_generation_server/utils/weights.py
@@ -127,15 +127,15 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
g_idx = w[0]

bits, groupsize = self._get_gptq_params()
use_exllama = False
use_gptq_cuda = False
if bits == 4:
from text_generation_server.utils.layers import HAS_EXLLAMA
from text_generation_server.utils.layers import HAS_GPTQ_CUDA

use_exllama = HAS_EXLLAMA
if use_exllama:
logger.info(f"Using exllama kernels for col {prefixes}")
use_gptq_cuda = HAS_GPTQ_CUDA
if use_gptq_cuda:
logger.info(f"Using GPTQ cuda kernels for col {prefixes}")

weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
@@ -145,34 +145,34 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
bits, groupsize = self._get_gptq_params()

use_exllama = bits == 4
use_gptq_cuda = bits == 4

if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
if g_idx is not None:
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
# The Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering input activations that are split across several GPUs
use_exllama = False
use_gptq_cuda = False

try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")

from text_generation_server.utils.layers import HAS_EXLLAMA
if use_exllama:
use_exllama = HAS_EXLLAMA
from text_generation_server.utils.layers import HAS_GPTQ_CUDA
if use_gptq_cuda:
use_gptq_cuda = HAS_GPTQ_CUDA
if self.process_group.rank == 0:
if use_exllama:
logger.info(f"Using exllama kernels for row {prefix}")
if use_gptq_cuda:
logger.info(f"Using GPTQ cuda kernels for row {prefix}")
else:
logger.warning(
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
"GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
" or not currently installed, try using BUILD_EXTENSIONS=True"
)

if use_exllama:
if use_gptq_cuda:
if groupsize >= 0:
# Exllama reorders the weights in advance and the activations on the fly, thus
# the scales and zero-points do not need to be reordered.
@@ -195,7 +195,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
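
Putting the two files together, the boolean at the end of the tuple built here is what drives the dispatch in get_linear; a condensed restatement of that hand-off (names from the diff):

# Produced by get_multi_weights_col / get_multi_weights_row when quantize == "gptq":
qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda = weight

# Consumed by get_linear in layers.py: the CUDA-backed class (exllama or marlin) when available,
# otherwise the default QuantLinear fallback.
linear_cls = GPTQ_CUDA_LINEAR if use_gptq_cuda else QuantLinear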