From 0a240b37cac5aaf50f9305331aa3ffda23652107 Mon Sep 17 00:00:00 2001
From: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Date: Fri, 8 Mar 2024 19:57:49 +0000
Subject: [PATCH 1/5] Incorporate marlin into tgis_native

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 .../flash_santacoder_modeling.py               |   4 +-
 server/text_generation_server/server.py        |  51 ++---
 .../utils/gptq/marlin.py                       | 186 ++++++++++++++++++
 server/text_generation_server/utils/layers.py  |  52 +++--
 .../text_generation_server/utils/weights.py    |  32 +--
 5 files changed, 263 insertions(+), 62 deletions(-)
 create mode 100644 server/text_generation_server/utils/gptq/marlin.py

diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 12fcd4dc..2e75070b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -68,8 +68,8 @@ def _load_multi_mqa_gptq(
         g_idx = g_idx.to(device=weights.device)
         bits, groupsize = weights._get_gptq_params()
 
-        from text_generation_server.utils.layers import HAS_EXLLAMA
-        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, HAS_EXLLAMA)
+        from text_generation_server.utils.layers import HAS_GPTQ_CUDA
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, HAS_GPTQ_CUDA)
 
         if bias:
             slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 3a2b02a2..7d295caa 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -277,31 +277,32 @@ async def serve_inner(
         print(model.config.__str__())
 
         if quantize == "gptq" and deployment_framework == "tgis_native":
-            from text_generation_server.utils.layers import HAS_EXLLAMA, EXLLAMA_VERSION
-            if HAS_EXLLAMA:
-                try:
-                    # When using GPTQ, Exllama kernels need some global kernels
-                    # For which we have the final shapes only after the model has loaded
-                    # This will allocate those buffers.
-
-                    if EXLLAMA_VERSION == "1":
-                        from text_generation_server.utils.gptq.exllama import (
-                            create_exllama_buffers, set_device,
-                        )
-                        set_device(device)
-                        create_exllama_buffers(max_sequence_length)
-                    else:
-                        assert EXLLAMA_VERSION == "2"
-                        from text_generation_server.utils.gptq.exllamav2 import (
-                            set_device, Ex4bitLinearV2,
-                        )
-                        set_device(device)
-                        for _, submodule in model.model.named_modules():
-                            if isinstance(submodule, Ex4bitLinearV2):
-                                submodule.post_init()  # make q matrix and set scratch space
-
-                except ImportError:
-                    print("WARN: Error setting up GPTQ exllama buffers")
+            from text_generation_server.utils.layers import HAS_GPTQ_CUDA, EXLLAMA_VERSION
+            if HAS_GPTQ_CUDA:
+                if EXLLAMA_VERSION is not None:
+                    try:
+                        # When using GPTQ, Exllama kernels need some global kernels
+                        # For which we have the final shapes only after the model has loaded
+                        # This will allocate those buffers.
+
+                        if EXLLAMA_VERSION == "1":
+                            from text_generation_server.utils.gptq.exllama import (
+                                create_exllama_buffers, set_device,
+                            )
+                            set_device(device)
+                            create_exllama_buffers(max_sequence_length)
+                        else:
+                            assert EXLLAMA_VERSION == "2"
+                            from text_generation_server.utils.gptq.exllamav2 import (
+                                set_device, Ex4bitLinearV2,
+                            )
+                            set_device(device)
+                            for _, submodule in model.model.named_modules():
+                                if isinstance(submodule, Ex4bitLinearV2):
+                                    submodule.post_init()  # make q matrix and set scratch space
+
+                    except ImportError:
+                        print("WARN: Error setting up GPTQ exllama buffers")
 
         if local_rank == 0 and device.type == "cuda":
             # Log GPU memory stats at startup
diff --git a/server/text_generation_server/utils/gptq/marlin.py b/server/text_generation_server/utils/gptq/marlin.py
new file mode 100644
index 00000000..5e043b36
--- /dev/null
+++ b/server/text_generation_server/utils/gptq/marlin.py
@@ -0,0 +1,186 @@
+# Adapted from https://github.com/AutoGPTQ/AutoGPTQ/blob/main/auto_gptq/nn_modules/qlinear/qlinear_marlin.py
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+try:
+    import autogptq_marlin_cuda
+except ImportError as e:
+    marlin_import_exception = e
+
+    def error_raiser_marlin(*args, **kwargs):
+        raise ValueError(
+            f"Trying to use the marlin backend, but could not import the C++/CUDA dependencies with the following error: {marlin_import_exception}"
+        )
+
+    autogptq_marlin_cuda = error_raiser_marlin
+
+
+def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):
+    """Marlin FP16xINT4 multiply; can be used within `torch.compile`.
+    @A: `torch.half` input matrix of shape `(m, k)` in standard row-major layout
+    @B: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
+    @C: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
+    @s: `torch.half` scales of shape `(m / group_size, n)`
+    @workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero
+    @thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1)
+    @thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1)
+    @sms: number of SMs to use for the kernel (can usually be left as auto -1)
+    @max_par: maximum number of batch 64 problems to solve in parallel for large input sizes
+    """
+    autogptq_marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par)
+
+
+# Precompute permutations for Marlin weight and scale shuffling
+def _get_perms():
+    perm = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm.extend([p + 256 * j for p in perm1])
+
+    perm = np.array(perm)
+    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    perm = perm.reshape((-1, 8))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return perm, scale_perm, scale_perm_single
+
+
+# _perm, _scale_perm, _scale_perm_single = _get_perms()
+
+# def unpack_qzeros(qzeros):
+#     unpacked_zeros = torch.zeros(
+#         (qzeros.shape[0], qzeros.shape[1] * 8),
+#         dtype=torch.int8,
+#         device=qzeros.device,
+#         requires_grad=False,
+#     )
+
+#     for col in range(unpacked_zeros.shape[1]):
+#         i = col % 8
+#         unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
+
+#     return unpacked_zeros + 1
+
+
+def pack(x, nbits=4):
+    pack_size = 32 // nbits
+    out = torch.zeros((x.shape[0]//pack_size, x.shape[1]), dtype=x.dtype, device=x.device)
+    bitmask = 2**nbits - 1
+    for i in range(pack_size):
+        out |= (x[i::pack_size] & bitmask) << (nbits*i)
+    return out
+
+
+def unpack(x, nbits=4, axis=0):
+    assert nbits == 4
+    bitmask = 2**nbits - 1
+    pack_size = 32 // nbits
+    dim0_size = x.shape[0] * pack_size if axis == 0 else x.shape[0]
+    dim1_size = x.shape[1] * pack_size if axis == 1 else x.shape[1]
+    output = torch.empty((dim0_size, dim1_size), dtype=x.dtype, layout=x.layout, device=x.device)
+
+    if axis == 0:
+        for i in range(pack_size):
+            output[i::pack_size, :] = (x >> (i*nbits)) & bitmask
+    elif axis == 1:
+        for i in range(pack_size):
+            output[:, i::pack_size] = (x >> (i*nbits)) & bitmask
+    else:
+        assert False, "invalid unpack axis"
+    return output
+
+
+class MarlinQuantLinear(nn.Module):
+    QUANT_TYPE = "marlin"
+
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, group_size):
+        super().__init__()
+
+        pack_size = 32 // bits
+        infeatures = qweight.shape[0] * pack_size
+        outfeatures = qweight.shape[1]
+
+        if not torch.cuda.get_device_capability()[0] >= 8:
+            raise ValueError(f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}.')
+        if infeatures % 128 != 0 or outfeatures % 256 != 0:
+            raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
+        if bits not in [4]:
+            raise NotImplementedError("Only 4 bits are supported.")
+        if group_size not in [-1, 128] and group_size != infeatures:
+            raise ValueError("Only group_size -1 and 128 are supported.")
+        if infeatures % group_size != 0:
+            raise ValueError("`infeatures` must be divisible by `group_size`.")
+
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
+        self.group_size = group_size if group_size != -1 else infeatures
+
+        self.desc_act = not ( g_idx is None
+            or torch.equal(g_idx, torch.arange(infeatures, device=qweight.device) // group_size) )
+
+        if self.desc_act:
+            # shuffle weight rows
+            self.perm = torch.argsort(g_idx)
+            # unpack --> shuffle --> pack
+            qweight = pack(unpack(qweight)[self.perm])
+
+        # Repack into marlin format
+        self.B = autogptq_marlin_cuda.gptq_repack(qweight)
+
+        # # Check symmetric quantization, very slow, skipping for now
+        # dequantized_qzeros = unpack_qzeros(qzeros)
+        # if not torch.all(dequantized_qzeros == 8):
+        #     raise ValueError(
+        #         "Marlin kernel is compatible only with checkpoints using symmetric quantization. "
+        #         "Found non-symmetric quantization for the weight {name}."
+        #     )
+
+        # Process scales
+        _, _scale_perm, _scale_perm_single = _get_perms()
+        s = scales.data.clone()
+        if group_size != infeatures:
+            s = s.reshape((1, -1))
+            s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+        else:
+            s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
+        s = s.reshape((-1, outfeatures)).contiguous()
+        self.s = s
+
+        # TODO: Can the workspace be shared among all marlin invocations?
+        self.workspace = torch.zeros(self.outfeatures // 128 * 16, dtype=torch.int, device=qweight.device)
+        self.bias = bias if bias is not None else None
+
+    def post_init(self):
+        pass
+
+    def forward(self, A):
+        A = A.half()
+        # Support activation reordering
+        if self.desc_act:
+            A = A[:, self.perm]
+        C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
+        mul(
+            A.view((-1, A.shape[-1])),
+            self.B,
+            C.view((-1, C.shape[-1])),
+            self.s,
+            self.workspace,
+        )
+        C = C + self.bias if self.bias is not None else C
+        return C
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index a817da1f..312f4d5d 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1,4 +1,6 @@
 import os
+from enum import Enum
+
 import torch
 import torch.distributed
 
@@ -11,8 +13,10 @@
 from accelerate import init_empty_weights
 
 HAS_BITS_AND_BYTES = False
-HAS_EXLLAMA = False
 EXLLAMA_VERSION = None
+HAS_GPTQ_CUDA = False
+GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "exllama").lower()
+GPTQ_CUDA_LINEAR = None
 
 if torch.cuda.is_available():
     try:
@@ -24,25 +28,35 @@
         from text_generation_server.utils.gptq.quant_linear import QuantLinear
 
-        if os.getenv("DISABLE_EXLLAMA", "False").lower() != "true":
-            try:
-                EXLLAMA_VERSION = os.getenv("EXLLAMA_VERSION", "2")  # Use v2 as default
-                if EXLLAMA_VERSION == "1":
-                    from text_generation_server.utils.gptq.exllama import Ex4bitLinear as ExllamaQuantLinear
-                elif EXLLAMA_VERSION == "2":
-                    from text_generation_server.utils.gptq.exllamav2 import Ex4bitLinearV2 as ExllamaQuantLinear
-                else:
-                    raise ValueError(f"Unsupported value for EXLLAMA_VERSION: {EXLLAMA_VERSION}")
-                HAS_EXLLAMA = True
-            except ImportError as e:
-                print_rank_n(f"Error importing ExllamaV{EXLLAMA_VERSION} kernels: {e}")
-                EXLLAMA_VERSION = None
+        if os.getenv("DISABLE_EXLLAMA", "False").lower() != "true":  # Turn off all GPTQ CUDA kernels if set to true
+            if GPTQ_CUDA_TYPE == "exllama":
+                try:
+                    EXLLAMA_VERSION = os.getenv("EXLLAMA_VERSION", "2")  # Use v2 as default
+                    if EXLLAMA_VERSION == "1":  # TODO: consider removing v1 kernel
+                        from text_generation_server.utils.gptq.exllama import Ex4bitLinear as ExllamaQuantLinear
+                    elif EXLLAMA_VERSION == "2":
+                        from text_generation_server.utils.gptq.exllamav2 import Ex4bitLinearV2 as ExllamaQuantLinear
+                    else:
+                        raise ValueError(f"Unsupported value for EXLLAMA_VERSION: {EXLLAMA_VERSION}")
+                    HAS_GPTQ_CUDA = True
+                    GPTQ_CUDA_LINEAR = ExllamaQuantLinear
+                except ImportError as e:
+                    print_rank_n(f"Error importing ExllamaV{EXLLAMA_VERSION} kernels: {e}")
+                    EXLLAMA_VERSION = None
+            elif GPTQ_CUDA_TYPE == "marlin":
+                try:
+                    from text_generation_server.utils.gptq.marlin import MarlinQuantLinear
+                    GPTQ_CUDA_LINEAR = MarlinQuantLinear
+                    HAS_GPTQ_CUDA = True
+                except ImportError as e:
+                    print_rank_n(f"Error importing Marlin kernels: {e}")
+            else:
+                print_rank_n(f"Invalid GPTQ_CUDA_TYPE {GPTQ_CUDA_TYPE}")
 
 print_rank_n(
-    f"HAS_BITS_AND_BYTES={HAS_BITS_AND_BYTES}, HAS_EXLLAMA={HAS_EXLLAMA}, EXLLAMA_VERSION={EXLLAMA_VERSION}"
+    f"HAS_BITS_AND_BYTES={HAS_BITS_AND_BYTES}, HAS_GPTQ_CUDA={HAS_GPTQ_CUDA}, EXLLAMA_VERSION={EXLLAMA_VERSION}, GPTQ_CUDA_TYPE={GPTQ_CUDA_TYPE}"
 )
 
-
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -169,13 +183,13 @@ def get_linear(weight, bias, quantize):
             linear.bias = nn.Parameter(bias)
     elif quantize == "gptq":
         try:
-            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
+            qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda = weight
         except Exception:
            raise NotImplementedError(
                f"The passed weight is not `gptq` compatible, loader needs to be updated."
            )
-
-        linear = (ExllamaQuantLinear if use_exllama else QuantLinear)(
+
+        linear = (QuantLinear if not use_gptq_cuda else GPTQ_CUDA_LINEAR)(
            qweight,
            qzeros,
            scales,
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 440c594d..3a53eb36 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -127,15 +127,15 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
             g_idx = w[0]
 
             bits, groupsize = self._get_gptq_params()
-            use_exllama = False
+            use_gptq_cuda = False
             if bits == 4:
-                from text_generation_server.utils.layers import HAS_EXLLAMA
+                from text_generation_server.utils.layers import HAS_GPTQ_CUDA
 
-                use_exllama = HAS_EXLLAMA
-                if use_exllama:
-                    logger.info(f"Using exllama kernels for col {prefixes}")
+                use_gptq_cuda = HAS_GPTQ_CUDA
+                if use_gptq_cuda:
+                    logger.info(f"Using GPTQ cuda kernels for col {prefixes}")
 
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
@@ -145,7 +145,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             bits, groupsize = self._get_gptq_params()
 
-            use_exllama = bits == 4
+            use_gptq_cuda = bits == 4
 
             if self.process_group.size() > 1:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
@@ -153,26 +153,26 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                 if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
                     # Exllama implementation does not support row tensor parallelism with act-order, as
                     # it would require to reorder input activations that are split unto several GPUs
-                    use_exllama = False
+                    use_gptq_cuda = False
 
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
-            from text_generation_server.utils.layers import HAS_EXLLAMA
-            if use_exllama:
-                use_exllama = HAS_EXLLAMA
+            from text_generation_server.utils.layers import HAS_GPTQ_CUDA
+            if use_gptq_cuda:
+                use_gptq_cuda = HAS_GPTQ_CUDA
 
             if self.process_group.rank == 0:
-                if use_exllama:
-                    logger.info(f"Using exllama kernels for row {prefix}")
+                if use_gptq_cuda:
+                    logger.info(f"Using GPTQ cuda kernels for row {prefix}")
                 else:
                     logger.warning(
-                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
+                        "GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
                         " or not currently installed, try using BUILD_EXTENSIONS=True"
                     )
 
-            if use_exllama:
+            if use_gptq_cuda:
                 if groupsize >= 0:
                     # Exllama reorders the weights in advance and the activations on the fly, thus
                     # the scales and zero-points do not need to be reordered.
@@ -195,7 +195,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                 scales = self.get_tensor(f"{prefix}.scales")
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
 
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight

From a1a380919a0124c43add2dd11b9c8cfa0829c5b8 Mon Sep 17 00:00:00 2001
From: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Date: Fri, 8 Mar 2024 18:36:32 -0500
Subject: [PATCH 2/5] Enable marlin as default GPTQ kernel

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 server/text_generation_server/utils/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 312f4d5d..f6a78a1e 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -15,7 +15,7 @@
 HAS_BITS_AND_BYTES = False
 EXLLAMA_VERSION = None
 HAS_GPTQ_CUDA = False
-GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "exllama").lower()
+GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "marlin").lower()
 GPTQ_CUDA_LINEAR = None
 
 if torch.cuda.is_available():

From 66af240579ea68af7646df1703f3f9de263bf298 Mon Sep 17 00:00:00 2001
From: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Date: Fri, 8 Mar 2024 21:28:28 -0500
Subject: [PATCH 3/5] Update server/text_generation_server/utils/gptq/marlin.py

Co-authored-by: Nick Hill
Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 server/text_generation_server/utils/gptq/marlin.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/utils/gptq/marlin.py b/server/text_generation_server/utils/gptq/marlin.py
index 5e043b36..d5a71c02 100644
--- a/server/text_generation_server/utils/gptq/marlin.py
+++ b/server/text_generation_server/utils/gptq/marlin.py
@@ -116,8 +116,9 @@ def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, group_size):
         infeatures = qweight.shape[0] * pack_size
         outfeatures = qweight.shape[1]
 
-        if not torch.cuda.get_device_capability()[0] >= 8:
-            raise ValueError(f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}.')
+        device_capability = torch.cuda.get_device_capability()
+        if not device_capability[0] >= 8:
+            raise ValueError(f'Can not use Marlin int4*fp16 kernel with a device of compute capability {device_capability}.')
         if infeatures % 128 != 0 or outfeatures % 256 != 0:
             raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
         if bits not in [4]:
             raise NotImplementedError("Only 4 bits are supported.")

From d418095e63198a455553d53f5c381c1200df6047 Mon Sep 17 00:00:00 2001
From: cyang49 <7364402+cyang49@users.noreply.github.com>
Date: Sun, 17 Mar 2024 11:16:50 -0400
Subject: [PATCH 4/5] Apply suggestion on GPTQ buffer setup

Signed-off-by: cyang49 <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 server/text_generation_server/server.py | 48 ++++++++++++------------
 1 file changed, 23 insertions(+), 25 deletions(-)

diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 7d295caa..aefa00ac 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -278,31 +278,29 @@ async def serve_inner(
 
         if quantize == "gptq" and deployment_framework == "tgis_native":
             from text_generation_server.utils.layers import HAS_GPTQ_CUDA, EXLLAMA_VERSION
-            if HAS_GPTQ_CUDA:
-                if EXLLAMA_VERSION is not None:
-                    try:
-                        # When using GPTQ, Exllama kernels need some global kernels
-                        # For which we have the final shapes only after the model has loaded
-                        # This will allocate those buffers.
-
-                        if EXLLAMA_VERSION == "1":
-                            from text_generation_server.utils.gptq.exllama import (
-                                create_exllama_buffers, set_device,
-                            )
-                            set_device(device)
-                            create_exllama_buffers(max_sequence_length)
-                        else:
-                            assert EXLLAMA_VERSION == "2"
-                            from text_generation_server.utils.gptq.exllamav2 import (
-                                set_device, Ex4bitLinearV2,
-                            )
-                            set_device(device)
-                            for _, submodule in model.model.named_modules():
-                                if isinstance(submodule, Ex4bitLinearV2):
-                                    submodule.post_init()  # make q matrix and set scratch space
-
-                    except ImportError:
-                        print("WARN: Error setting up GPTQ exllama buffers")
+            if HAS_GPTQ_CUDA and EXLLAMA_VERSION is not None:
+                try:
+                    # When using GPTQ, Exllama kernels need some global kernels
+                    # For which we have the final shapes only after the model has loaded
+                    # This will allocate those buffers.
+                    if EXLLAMA_VERSION == "1":
+                        from text_generation_server.utils.gptq.exllama import (
+                            create_exllama_buffers, set_device,
+                        )
+                        set_device(device)
+                        create_exllama_buffers(max_sequence_length)
+                    elif EXLLAMA_VERSION == "2":
+                        from text_generation_server.utils.gptq.exllamav2 import (
+                            set_device, Ex4bitLinearV2,
+                        )
+                        set_device(device)
+                        for _, submodule in model.model.named_modules():
+                            if isinstance(submodule, Ex4bitLinearV2):
+                                submodule.post_init()  # make q matrix and set scratch space
+                    else:
+                        raise ValueError(f"Unsupported {EXLLAMA_VERSION=}")
+                except ImportError:
+                    print("WARN: Error setting up GPTQ exllama buffers")
 
         if local_rank == 0 and device.type == "cuda":
             # Log GPU memory stats at startup

From 57cba2d56a7e133a4722b12277ce198eb7a56ec3 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Mon, 25 Mar 2024 13:50:37 +0000
Subject: [PATCH 5/5] Change default GPTQ kernel back to exllama

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 server/text_generation_server/utils/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index f6a78a1e..312f4d5d 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -15,7 +15,7 @@
 HAS_BITS_AND_BYTES = False
 EXLLAMA_VERSION = None
 HAS_GPTQ_CUDA = False
-GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "marlin").lower()
+GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "exllama").lower()
 GPTQ_CUDA_LINEAR = None
 
 if torch.cuda.is_available():
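
Note on kernel selection: with this series applied, the GPTQ CUDA backend is chosen at import time of text_generation_server.utils.layers from environment variables. A minimal sketch of how a deployment could exercise the switch, assuming the patches are applied and the corresponding kernel extensions are installed; the final print is illustrative only and mirrors the print_rank_n banner in layers.py:

    import os

    # Choose the GPTQ CUDA backend before layers.py is imported.
    # Patch 2 made "marlin" the default; patch 5 reverts the default to "exllama".
    os.environ["GPTQ_CUDA_TYPE"] = "marlin"

    # Only consulted on the exllama path ("1" or "2"); ignored when marlin is selected.
    os.environ["EXLLAMA_VERSION"] = "2"

    # "true" turns off all GPTQ CUDA kernels and falls back to the Triton QuantLinear.
    os.environ["DISABLE_EXLLAMA"] = "false"

    from text_generation_server.utils import layers

    print(layers.HAS_GPTQ_CUDA, layers.GPTQ_CUDA_TYPE, layers.GPTQ_CUDA_LINEAR)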
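Note on the pack/unpack helpers in marlin.py: for 4-bit values they are inverses along dim 0, which is what lets MarlinQuantLinear.__init__ reorder the rows of a packed qweight when desc_act is set (unpack, permute, re-pack) before handing it to gptq_repack. A small self-contained check, assuming the patched module is importable; it does not need a GPU, since only the packing helpers are exercised:

    import torch

    from text_generation_server.utils.gptq.marlin import pack, unpack

    # 4-bit values laid out as an unpacked (infeatures, outfeatures) matrix.
    unpacked = torch.randint(0, 16, (256, 64), dtype=torch.int32)

    packed = pack(unpacked)                       # eight rows packed into each int32 row
    assert packed.shape == (32, 64)
    assert torch.equal(unpack(packed), unpacked)  # round trip along dim 0

    # Row shuffle as done for desc_act: unpack, permute rows, re-pack.
    perm = torch.randperm(256)
    shuffled = pack(unpack(packed)[perm])
    assert torch.equal(unpack(shuffled), unpacked[perm])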