diff --git a/auto_round/inference/backend.py b/auto_round/inference/backend.py index d3c5132ac..3d3e837f8 100644 --- a/auto_round/inference/backend.py +++ b/auto_round/inference/backend.py @@ -48,11 +48,11 @@ class BackendInfo: Attributes: device: A list of strings representing the devices the backend supports - (e.g., 'cuda', 'cpu'). + (e.g., 'cpu', 'xpu', 'cuda'). sym: A list of booleans indicating whether the backend supports symmetric quantization for weights (True if symmetric, False if not). packing_format: A list of strings representing the packing formats used by the backend - (e.g., 'triton', 'qbits'). + (e.g., 'ark', 'triton'). bits: A list of integers specifying the bit-widths supported by the backend for weight quantization (e.g., [2, 4, 8]). group_size: An optional list of integers specifying the group sizes supported @@ -437,51 +437,54 @@ def fp8_static_scheme_checker( requirements=["autoawq", "transformers"], ) -BackendInfos["qbits"] = BackendInfo( - device=["cpu"], - sym=[True, False], +BackendInfos["auto_round_kernel"] = BackendInfo( + device=["cpu", "xpu"], + sym=[True], packing_format=GPTQ_FORMAT_NO_ZP, bits=[2, 4, 8], group_size=None, - priority=1, + priority=6, checkers=[], - alias=["itrex", "qbits"], - compute_dtype=["float16", "bfloat16"], + alias=["ark"], + compute_dtype=["float32", "float16"], data_type=["int"], act_bits=WOQ_DEFAULT_ACT_BITS, - requirements=["torch<2.7.0", "intel-extension-for-transformers"], + requirements=["torch>=2.9.0"], + systems=["linux"], ) -BackendInfos["qbits_zp"] = BackendInfo( - device=["cpu"], - sym=[True, False], +BackendInfos["auto_round_kernel_zp"] = BackendInfo( + device=["cpu", "xpu"], + sym=[True], packing_format=GPTQ_FORMAT, bits=[2, 4, 8], group_size=None, - compute_dtype=["float16", "bfloat16"], + priority=6, + checkers=[], + alias=["ark"], + compute_dtype=["float32", "float16"], data_type=["int"], act_bits=WOQ_DEFAULT_ACT_BITS, - priority=1, - checkers=[], - alias=["itrex", "qbits"], - requirements=["torch<2.7.0", "intel-extension-for-transformers"], + requirements=["torch>=2.9.0"], + systems=["linux"], ) - -BackendInfos["qbits_awq"] = BackendInfo( +BackendInfos["auto_round_kernel_awq"] = BackendInfo( device=["cpu"], sym=[True, False], packing_format=AWQ_FORMAT, bits=[2, 4, 8], group_size=None, - compute_dtype=["float16", "bfloat16"], + priority=6, + checkers=[], + alias=["ark"], + compute_dtype=["float32", "float16"], data_type=["int"], act_bits=WOQ_DEFAULT_ACT_BITS, - priority=1, - checkers=[], - alias=["itrex", "qbits"], - requirements=["torch<2.7.0", "intel-extension-for-transformers"], + requirements=["torch>=2.9.0"], + systems=["linux"], ) + BackendInfos["ipex_gptq"] = BackendInfo( device=["cpu", "xpu"], sym=[True, False], @@ -613,12 +616,12 @@ def dynamic_import_inference_linear(backend, config): """Dynamically imports and returns the appropriate QuantLinear class based on the given backend. This function dynamically loads the correct `QuantLinear` class based on the backend and quantization - configuration (e.g., qbits, marlin, hpu, gptq, awq, auto_round). It imports specific modules or raises + configuration (e.g., ark, marlin, hpu, gptq, awq). It imports specific modules or raises errors if the required packages are not installed or the environment is not set up. Args: backend (str): - The backend to be used for quantization (e.g., 'qbits', 'marlin', 'hpu', 'gptq', 'awq', 'auto_round'). + The backend to be used for quantization (e.g., 'ark', 'marlin', 'hpu', 'gptq', 'awq'). 
config (QuantizationScheme): The quantization configuration containing parameters like bits, group_size, and sym. @@ -628,7 +631,7 @@ def dynamic_import_inference_linear(backend, config): Raises: ImportError: - If required modules are missing for a backend (e.g., Intel Extension, GPTQ, auto_awq). + If required modules are missing for a backend (e.g., ark, GPTQ, auto_awq). """ bits, group_size, sym = config["bits"], config["group_size"], config["sym"] @@ -641,26 +644,20 @@ def dynamic_import_inference_linear(backend, config): if "torch_nvfp4" in backend: return ar_qmodules.NVFP4QuantLinear - if "qbits" in backend: + if "auto_round_kernel" in backend or "ark" in backend: try: - from intel_extension_for_transformers import qbits # pylint: disable=E0401 + from auto_round_extension.ark import auto_round_kernel as ark # pylint: disable=E0611, E0401 except Exception as e: - raise ImportError( - "Please install Intel Extension for Transformers via 'pip install " - "intel-extension-for-transformers' to inference on X86 CPU" - ) - if "zp" in backend: - import auto_round_extension.qbits.qlinear_qbits_gptq as qlinear_qbits_gptq + raise ImportError("Please install auto_round_kernel version for CPU/XPU") + import auto_round_extension.ark.qlinear as qlinear - return qlinear_qbits_gptq.QuantLinear + if "zp" in backend: + return qlinear.QuantLinearGPTQ elif "awq" in backend: - import auto_round_extension.qbits.qbits_awq as qlinear_qbits_awq - - return qlinear_qbits_awq.QuantLinear + return qlinear.QuantLinearAWQ else: # auto_round must be at the end - import auto_round_extension.qbits.qlinear_qbits as qlinear_qbits_autoround + return qlinear.QuantLinear - return qlinear_qbits_autoround.QuantLinear if "ipex_gptq" in backend: from auto_round_extension.ipex.qlinear_ipex_gptq import QuantLinear diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py index 7a55eda03..4d464763b 100644 --- a/auto_round/inference/convert_model.py +++ b/auto_round/inference/convert_model.py @@ -31,7 +31,6 @@ from auto_round.inference.utils import _expand_regex_config from auto_round.logger import logger from auto_round.schemes import QuantizationScheme -from auto_round.special_model_handler import _handle_moe_model from auto_round.utils import ( SUPPORTED_LAYER_TYPES, check_start_with_block_name, @@ -395,9 +394,9 @@ def _create_quant_layer(layer, layer_backend, config, in_features, out_features) bias = layer.bias is not None # Special handling for AWQ layers - from auto_round_extension.qbits.qbits_awq import QuantLinear as QBitsAWQQuantLinear + from auto_round_extension.ark.qlinear import QuantLinearAWQ - if "awq" in layer_backend and isinstance(QuantLinear, QBitsAWQQuantLinear): + if "awq" in layer_backend and isinstance(QuantLinear, QuantLinearAWQ): return QuantLinear.from_linear( layer, config["bits"], config["group_size"], init_only=True, has_zero_points=not config["sym"] ) @@ -474,7 +473,6 @@ def post_init(model: torch.nn.Module, used_backends: list[str]) -> None: need_gptqmodel_init = False need_ipex_itrex_init = False used_gptq_exllamav2 = False - # Determine which backends require post-init for backend in used_backends: if backend.startswith("auto_gptq"): @@ -483,7 +481,7 @@ def post_init(model: torch.nn.Module, used_backends: list[str]) -> None: used_gptq_exllamav2 = True elif backend.startswith("gptqmodel"): need_gptqmodel_init = True - elif backend.startswith(("ipex", "qbit")): + elif backend.startswith(("ipex", "auto_round_kernel")): need_ipex_itrex_init = True # AutoGPTQ post-init @@ 
-503,7 +501,7 @@ def post_init(model: torch.nn.Module, used_backends: list[str]) -> None: message = "repacking to CPU/XPU format" layers = [] ## ipex post_init will add one more layer for n, m in model.named_modules(): - if hasattr(m, "QUANT_TYPE") and ("qbits" in m.QUANT_TYPE or "ipex" in m.QUANT_TYPE): + if hasattr(m, "QUANT_TYPE") and ("ark" in m.QUANT_TYPE or "ipex" in m.QUANT_TYPE): layers.append(m) for layer in tqdm(layers, desc=message, total=len(layers), leave=True): @@ -583,9 +581,6 @@ def convert_hf_model(model: nn.Module, target_device: str = "cpu") -> tuple[nn.M elif packing_format == "auto_round:gptq": packing_format = "auto_round:auto_gptq" - # Preprocess model before replace layers - model = _handle_moe_model(model) - # Replace layers with quantized versions layer_configs = get_layer_config(model, quantization_config) used_backends = _replace_by_quant_layers(model, layer_configs, backend, target_device, packing_format) diff --git a/auto_round_extension/qbits/__init__.py b/auto_round_extension/ark/__init__.py similarity index 60% rename from auto_round_extension/qbits/__init__.py rename to auto_round_extension/ark/__init__.py index 04cea3305..76e0da18c 100644 --- a/auto_round_extension/qbits/__init__.py +++ b/auto_round_extension/ark/__init__.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from auto_round_extension.qbits.qlinear_qbits import QuantLinear as QBitsQuantLinear -from auto_round_extension.qbits.qlinear_qbits_gptq import ( - QuantLinear as QBitsGPTQQuantLinear, -) -from auto_round_extension.qbits.qbits_awq import QuantLinear as QBitsAWQQuantLinear +from auto_round_extension.ark.qlinear import QuantLinear, QuantLinearGPTQ, QuantLinearAWQ -qbits_qlinear_classes = (QBitsQuantLinear, QBitsGPTQQuantLinear) +qlinear_classes = (QuantLinear, QuantLinearGPTQ) -qbits_awq_classes = (QBitsAWQQuantLinear,) +awq_classes = (QuantLinearAWQ,) diff --git a/auto_round_extension/ark/qlinear.py b/auto_round_extension/ark/qlinear.py new file mode 100644 index 000000000..eabe56248 --- /dev/null +++ b/auto_round_extension/ark/qlinear.py @@ -0,0 +1,420 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
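+#
+# Module overview: this file provides the CPU/XPU weight-only quantized linear layers backed by
+# the auto_round_kernel ("ark") extension:
+#   * QuantLinear      - GPTQ packing without the +1 zero-point offset (QUANT_TYPE "ark_gptq_nozp")
+#   * QuantLinearGPTQ  - GPTQ packing with the standard +1 zero-point offset (QUANT_TYPE "ark_gptq")
+#   * QuantLinearAWQ   - AWQ packing, CPU only (QUANT_TYPE "ark_awq")
+#
+# Rough usage sketch (illustrative shapes only; layers are normally created via
+# auto_round.inference.backend.dynamic_import_inference_linear rather than constructed by hand):
+#   layer = QuantLinear(bits=4, group_size=32, infeatures=768, outfeatures=768, bias=True)
+#   # ...load the qweight/qzeros/scales checkpoint tensors into the registered buffers...
+#   layer.post_init()   # unpacks the GPTQ tensors and repacks them with ark.repack_quantized_weight
+#   y = layer(x)        # ark.woq_linear: fp32 compute on CPU, fp16 compute on XPU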
+ +import math + +import torch +import torch.nn as nn + +from auto_round.utils import convert_dtype_torch2str, logger + +try: + from auto_round_extension.ark import auto_round_kernel as ark + + ARK_INSTALLED = True +except: + ARK_INSTALLED = False + +BITS_DTYPE_MAPPING = { + 2: "int2", + 4: "int4", + 8: "int8", +} + +AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): + shifts = torch.arange(0, 32, bits, device="cpu") + + # unpacking columnwise + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 # smallest dtype available + ) + iweights = iweights.view(iweights.shape[0], -1) + + # unpacking columnwise + if qzeros is not None: + izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( + torch.int8 # smallest dtype available + ) + izeros = izeros.view(izeros.shape[0], -1) + else: + izeros = qzeros + + return iweights, izeros + + +def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): + reverse_order_tensor = torch.arange( + iweights.shape[-1], + dtype=torch.int32, + device="cpu", + ) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] + reverse_order_tensor = reverse_order_tensor.view(-1) + + if izeros is not None: + izeros = izeros[:, reverse_order_tensor] + iweights = iweights[:, reverse_order_tensor] + return iweights, izeros + + +class QuantLinearAWQ(nn.Module): + QUANT_TYPE = "ark_awq" + + def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev): + super().__init__() + assert ARK_INSTALLED, "Please install auto_round_kernel package." + + self.use_bf16 = ark.check_isa_supported("AMX") + + if w_bit not in [2, 3, 4, 8]: + raise NotImplementedError("Only 2, 3, 4, 8 bits are supported for now.") + + self.in_features = in_features + self.out_features = out_features + self.w_bit = w_bit + self.group_size = group_size if group_size != -1 else in_features + self.zero_point = zero_point + self.scale_dtype = torch.float32 + + # quick sanity check (make sure alignment) + assert self.in_features % self.group_size == 0 + assert out_features % (32 // self.w_bit) == 0 + self.pack_num = 32 // self.w_bit + self.register_buffer( + "qzeros", + torch.zeros( + (in_features // self.group_size, out_features // self.pack_num), + dtype=torch.int8, + device=dev, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + (in_features // self.group_size, out_features), + dtype=torch.bfloat16 if self.use_bf16 else torch.float32, + device=dev, + ), + ) + if bias: + self.register_buffer( + "bias", + torch.zeros((out_features), dtype=torch.bfloat16 if self.use_bf16 else torch.float32, device=dev), + ) + else: + self.register_buffer( + "bias", + None, + ) + qweight = torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev) + self.register_buffer("qweight", qweight) + + def post_init(self): + assert self.qweight.device.type == "cpu" + + intweight, zeros = unpack_awq(self.qweight, self.qzeros, self.w_bit) # weight: k x n zeros: k / group_size x n + intweight, zeros = reverse_awq_order(intweight, zeros, self.w_bit) # weight: k x n zeros: k / group_size x n + if self.zero_point: ## asym has accuracy issue, have not root caused yet + intweight = torch.bitwise_and(intweight, (2**self.w_bit) - 1) - (2 ** (self.w_bit - 1)) + zeros = torch.bitwise_and(zeros, (2**self.w_bit) - 1) - (2 ** (self.w_bit - 1)) + else: + 
##symmetric, our default zp is 8 + intweight = torch.bitwise_and(intweight, (2**self.w_bit) - 1) - (2 ** (self.w_bit - 1)) + g_idx = torch.empty(0, dtype=torch.int32) + self.qweight = ark.repack_quantized_weight( + intweight, + self.scales.float(), + zeros, + g_idx, + BITS_DTYPE_MAPPING[self.w_bit], + convert_dtype_torch2str(self.scale_dtype), + convert_dtype_torch2str(self.scales.dtype), + self.zero_point, + self.group_size, + ) + + @classmethod + def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False): + awq_linear = cls( + w_bit, + group_size, + linear.in_features, + linear.out_features, + linear.bias is not None, + has_zero_points, + linear.weight.device, + ) + if init_only: # just prepare for loading sd + return awq_linear + + raise NotImplementedError("Only inference is supported for Exllama kernels") + + @torch.no_grad() + def forward(self, x): + assert ARK_INSTALLED, "ARK kernels could not be loaded. " + + input_dtype = x.dtype + out_shape = x.shape[:-1] + (self.out_features,) + x = x.view(-1, x.shape[-1]) # convert xd to 2d + out_2d_shape = x.shape[:-1] + (self.out_features,) + + outputs = torch.zeros(out_2d_shape, dtype=input_dtype) + bias = ( + self.bias + if self.bias is not None + else torch.empty(0, dtype=torch.bfloat16 if self.use_bf16 else torch.float32) + ) + + ark.woq_linear( + x, + self.qweight, + bias, + outputs, + convert_dtype_torch2str(input_dtype), + BITS_DTYPE_MAPPING[self.w_bit], + convert_dtype_torch2str(self.scale_dtype), + True, + ) + + return outputs.view(out_shape) + + def extra_repr(self) -> str: + return "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( + self.in_features, + self.out_features, + self.bias is not None, + self.w_bit, + self.group_size, + ) + + +class QuantLinear(nn.Module): + QUANT_TYPE = "ark_gptq_nozp" + ZP_BIAS = 0 + + def __init__( + self, + bits, + group_size, + infeatures, + outfeatures, + bias, + kernel_switch_threshold=128, + trainable=False, + weight_dtype=torch.bfloat16, + **kwargs, + ): + super().__init__() + + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2, 4,8 bits are supported for ARK.") + assert ARK_INSTALLED, "Please install auto_round_kernel." 
+ + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.group_size = group_size if group_size != -1 else infeatures + self.maxq = 2**self.bits - 1 + self.weight_dtype = weight_dtype + self.asym = True + ark.set_threads(torch.get_num_threads()) + self.register_buffer( + "qweight", + torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), + ) + self.register_buffer( + "qzeros", + torch.zeros( + ( + math.ceil(infeatures / self.group_size), + outfeatures // 32 * self.bits, + ), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + (math.ceil(infeatures / self.group_size), outfeatures), + dtype=weight_dtype, + ), + ) + if bias: + self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float)) + else: + self.bias = None + + self.kernel_switch_threshold = kernel_switch_threshold + self.trainable = trainable + + def post_init(self): + assert self.qweight.device.type in ["cpu", "xpu"] + # intweight: k x n, zeros: k / group_size x n + intweight, zeros = unpack_to_8bit_signed(self.qweight, self.qzeros, self.bits, self.ZP_BIAS) + if zeros is None: + zeros = torch.empty(0, dtype=torch.int8) + self.asym = False + else: + # change it to int8 with offset 128 + if self.bits == 8: + zeros = (zeros.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8) + else: + zeros -= 2 ** (self.bits - 1) + if self.qweight.device.type != "cpu": + assert not self.asym + if not self.asym: + intweight -= 2 ** (self.bits - 1) + intweight = intweight.to(torch.uint8 if self.asym else torch.int8) + # due to asym return torch.uint8 but backend request int8, + # change it to int8 with offset 128 + if self.asym: + intweight = (intweight.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8) + + logger.debug( + f"ARK repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}" + ) + + if self.qweight.device.type == "xpu": + self.sdt = "fp16" + self.cdt = "fp16" + scales = self.scales.to(torch.float16).contiguous() + else: + self.sdt = "fp32" + self.cdt = "fp32" + scales = self.scales.float().contiguous() + self.wdt = BITS_DTYPE_MAPPING[self.bits] + + self.qweight = ark.repack_quantized_weight( + intweight.contiguous(), + scales, + zeros.contiguous(), + torch.Tensor(), + # compute_dtype + self.cdt, + # weight_dtype + self.wdt, + # scale_dtype + self.sdt, + self.asym, + self.group_size, + ) + + # self.revert_wei = torch.zeros(self.infeatures, self.outfeatures, dtype=scales.dtype, device=self.qweight.device) + # # print(packw, packw.device, packw.dtype) + # ark.dequantize_packed_weight( + # self.qweight, self.revert_wei, False, self.cdt, self.wdt, self.sdt, self.group_size, self.outfeatures, self.infeatures) + # free mem + self.qzeros = torch.empty(0) + self.scales = torch.empty(0) + if self.bias is not None: + if self.bias.device.type == "cpu": + self.bias = self.bias.to(torch.float32) + else: + self.bias = self.bias.to(torch.float16) + + def forward(self, x: torch.Tensor): + raw_input_dtype = x.dtype + if x.device.type == "cpu": + odt = torch.float32 + if raw_input_dtype != torch.float32: + x = x.to(torch.float32) + else: + odt = x.dtype + + out_shape = x.shape[:-1] + (self.outfeatures,) + x = x.view(-1, x.shape[-1]) # convert xd to 2d + out_2d_shape = x.shape[:-1] + (self.outfeatures,) + outputs = torch.empty(out_2d_shape, device=x.device, dtype=odt) + bias = self.bias if self.bias is not None else 
torch.empty(0, dtype=torch.float) + + ark.woq_linear( + x, + self.qweight, + bias, + outputs, + self.cdt, # compute_dtype + self.wdt, # weight_dtype + self.sdt, # scale_dtype + self.asym, + self.group_size, + ) + if x.device.type == "xpu": + outputs = outputs + bias + return outputs.to(raw_input_dtype).view(out_shape) + + +class QuantLinearGPTQ(QuantLinear): + QUANT_TYPE = "ark_gptq" + ZP_BIAS = 1 + + +@torch.no_grad() +def unpack_to_8bit_signed(qweight, qzeros, bits, gptq_bias=1): + wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qweight.device).unsqueeze(0) + zeros = None + if not torch.all(torch.eq(qzeros, 2004318071 if bits == 4 else 0b01111111011111110111111101111111)): + zp_shape = list(qzeros.shape) + zp_shape[1] = zp_shape[1] * (32 // bits) + + zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to( + torch.int16 if bits == 8 else torch.int8 + ) + torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) + if bits == 8: + zeros = zeros.to(torch.uint8) + zeros += gptq_bias + try: + zeros = zeros.reshape(zp_shape) + except: + # zeros and scales have different iteam numbers. + # remove 1 (due to 0 + 1 in line 252) + zeros = zeros[zeros != 1] + zeros = zeros.reshape(zp_shape) + weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to( + torch.int16 if bits == 8 else torch.int8 + ) + weight.bitwise_and_((2**bits) - 1) + weight = weight.view(-1, weight.shape[-1]) + + return weight, zeros + + +# Copied from qlinear_marlin.py +@torch.no_grad() +def dequantize_weight(qweight, qzeros, scales, bits): + unpacked_qweight, unpacked_qzeros = unpack_to_8bit_signed(qweight, qzeros, bits) + group_size = unpacked_qweight.shape[0] // scales.shape[0] + scales = scales.repeat_interleave(group_size, dim=0) + if unpacked_qzeros is not None: + unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0) + else: + unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32, device=qweight.device) + unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales + + return unpacked_qweight, unpacked_qzeros + + +def ark_post_init(model): + for _, submodule in model.named_modules(): + if isinstance(submodule, QuantLinear): + submodule.post_init() + + return model + + +__all__ = ["QuantLinear", "QuantLinearGPTQ", "QuantLinearAWQ"] diff --git a/auto_round_extension/qbits/qbits_awq.py b/auto_round_extension/qbits/qbits_awq.py deleted file mode 100644 index d4b57505f..000000000 --- a/auto_round_extension/qbits/qbits_awq.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright (c) 2025 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch -import torch.nn as nn - -AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] - - -def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): - shifts = torch.arange(0, 32, bits, device="cpu") - - # unpacking columnwise - iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( - torch.int8 # smallest dtype available - ) - iweights = iweights.view(iweights.shape[0], -1) - - # unpacking columnwise - if qzeros is not None: - izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( - torch.int8 # smallest dtype available - ) - izeros = izeros.view(izeros.shape[0], -1) - else: - izeros = qzeros - - return iweights, izeros - - -def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): - reverse_order_tensor = torch.arange( - iweights.shape[-1], - dtype=torch.int32, - device="cpu", - ) - reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) - reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] - reverse_order_tensor = reverse_order_tensor.view(-1) - - if izeros is not None: - izeros = izeros[:, reverse_order_tensor] - iweights = iweights[:, reverse_order_tensor] - return iweights, izeros - - -try: - from intel_extension_for_transformers import qbits # with QBits kernels () - - QBITS_INSTALLED = True -except: - QBITS_INSTALLED = False - -BITS_DTYPE_MAPPING = { - 4: "int4_clip", - 8: "int8", -} - - -def convert_dtype_torch2str(dtype): - if dtype == torch.int8: - return "int8" - elif dtype == torch.float: - return "fp32" - elif dtype == torch.float16: - return "fp16" - elif dtype == torch.bfloat16: - return "bf16" - elif isinstance(dtype, str) and dtype in ["int8", "fp32", "fp16", "bf16"]: - return dtype - else: - assert False, "Unsupported pytorch dtype {} to str dtype".format(dtype) - - -class QuantLinear(nn.Module): - QUANT_TYPE = "qbits_awq" - - def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev): - super().__init__() - assert ( - QBITS_INSTALLED - ), "Please install ITREX qbits package with `pip install intel-extension-for-transformers`." 
- - self.use_bf16 = qbits.check_isa_supported("AMX") - - if w_bit not in [2, 3, 4, 8]: - raise NotImplementedError("Only 2, 3, 4, 8 bits are supported for now.") - - self.in_features = in_features - self.out_features = out_features - self.w_bit = w_bit - self.group_size = group_size if group_size != -1 else in_features - self.zero_point = zero_point - self.scale_dtype = torch.float32 - - # quick sanity check (make sure alignment) - assert self.in_features % self.group_size == 0 - assert out_features % (32 // self.w_bit) == 0 - self.pack_num = 32 // self.w_bit - self.register_buffer( - "qzeros", - torch.zeros( - (in_features // self.group_size, out_features // self.pack_num), - dtype=torch.int8, - device=dev, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (in_features // self.group_size, out_features), - dtype=torch.bfloat16 if self.use_bf16 else torch.float32, - device=dev, - ), - ) - if bias: - self.register_buffer( - "bias", - torch.zeros((out_features), dtype=torch.bfloat16 if self.use_bf16 else torch.float32, device=dev), - ) - else: - self.register_buffer( - "bias", - None, - ) - qweight = torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev) - self.register_buffer("qweight", qweight) - - def post_init(self): - assert self.qweight.device.type == "cpu" - - intweight, zeros = unpack_awq(self.qweight, self.qzeros, self.w_bit) # weight: k x n zeros: k / group_size x n - intweight, zeros = reverse_awq_order(intweight, zeros, self.w_bit) # weight: k x n zeros: k / group_size x n - if self.zero_point: ## asym has accuracy issue, have not root caused yet - intweight = torch.bitwise_and(intweight, (2**self.w_bit) - 1) - (2 ** (self.w_bit - 1)) - zeros = torch.bitwise_and(zeros, (2**self.w_bit) - 1) - (2 ** (self.w_bit - 1)) - else: - ##symmetric, our default zp is 8 - intweight = torch.bitwise_and(intweight, (2**self.w_bit) - 1) - (2 ** (self.w_bit - 1)) - g_idx = torch.empty(0, dtype=torch.int32) - self.qweight = qbits.repack_quantized_weight( - intweight, - self.scales.float(), - zeros, - g_idx, - BITS_DTYPE_MAPPING[self.w_bit], - convert_dtype_torch2str(self.scale_dtype), - convert_dtype_torch2str(self.scales.dtype), - self.zero_point, - self.group_size, - ) - - @classmethod - def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - has_zero_points, - linear.weight.device, - ) - if init_only: # just prepare for loading sd - return awq_linear - - raise NotImplementedError("Only inference is supported for Exllama kernels") - - @torch.no_grad() - def forward(self, x): - assert QBITS_INSTALLED, ( - "QBits kernels could not be loaded. 
" - "Please install with `pip install intel-extension-for-transformers` and " - "refer to the detail https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md" - ) - - input_dtype = x.dtype - out_shape = x.shape[:-1] + (self.out_features,) - x = x.view(-1, x.shape[-1]) # convert xd to 2d - out_2d_shape = x.shape[:-1] + (self.out_features,) - - outputs = torch.zeros(out_2d_shape, dtype=input_dtype) - bias = ( - self.bias - if self.bias is not None - else torch.empty(0, dtype=torch.bfloat16 if self.use_bf16 else torch.float32) - ) - - qbits.woq_linear( - x, - self.qweight, - bias, - outputs, - convert_dtype_torch2str(input_dtype), - BITS_DTYPE_MAPPING[self.w_bit], - convert_dtype_torch2str(self.scale_dtype), - True, - ) - - return outputs.view(out_shape) - - def extra_repr(self) -> str: - return "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.w_bit, - self.group_size, - ) - - -def qbits_post_init(model): - for _, submodule in model.named_modules(): - if isinstance(submodule, QuantLinear): - submodule.post_init() - - return model diff --git a/auto_round_extension/qbits/qlinear_qbits.py b/auto_round_extension/qbits/qlinear_qbits.py deleted file mode 100644 index 202f1eb19..000000000 --- a/auto_round_extension/qbits/qlinear_qbits.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math - -import numpy as np -import torch -import torch.nn as nn - -from auto_round.utils import convert_dtype_torch2str, logger - -QBITS_AVAILABLE = True - -BITS_DTYPE_MAPPING = { - 2: "int2_clip", - 4: "int4_clip", - 8: "int8", -} - - -class QuantLinear(nn.Module): - QUANT_TYPE = "qbits" - - def __init__( - self, - bits, - group_size, - infeatures, - outfeatures, - bias, - kernel_switch_threshold=128, - trainable=False, - weight_dtype=torch.bfloat16, - **kwargs, - ): - super().__init__() - - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2, 4,8 bits are supported for QBits.") - - self.infeatures = infeatures - self.outfeatures = outfeatures - self.bits = bits - self.group_size = group_size if group_size != -1 else infeatures - self.maxq = 2**self.bits - 1 - self.weight_dtype = weight_dtype - self.asym = True - self.qbits = None - - self.register_buffer( - "qweight", - torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), - ) - self.register_buffer( - "qzeros", - torch.zeros( - ( - math.ceil(infeatures / self.group_size), - outfeatures // 32 * self.bits, - ), - dtype=torch.int32, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (math.ceil(infeatures / self.group_size), outfeatures), - dtype=weight_dtype, - ), - ) - if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float)) - else: - self.bias = None - - self.kernel_switch_threshold = kernel_switch_threshold - - self.trainable = trainable - - def req_check(self): - torch_version = str(torch.__version__) - if QBITS_AVAILABLE: - pass - # import intel_extension_for_transformers - # itrex_version = str(intel_extension_for_transformers.__version__) - # version_match_map = {"1.4": "2.2.0+cpu", - # "1.4.1": "2.2.0+cpu", "1.4.2": "2.3.0+cpu"} - # if itrex_version in version_match_map: - # if torch_version != version_match_map[itrex_version]: - # logger.warning( - # f"Please install torch {version_match_map[itrex_version]} by command 'pip install torch=={version_match_map[itrex_version]} --extra-index-url https://download.pytorch.org/whl/cpu' as Intel Extension for Transformers {itrex_version} is not compatible with current torch.") - else: - logger.error( - "Please install Intel Extension for Transformers by running 'pip install intel-extension-for-transformers' as qbits linear requirements checking fail. 
" - ) - exit(1) - - def post_init(self): - import intel_extension_for_transformers - - self.qbits = intel_extension_for_transformers.qbits - assert self.qweight.device.type == "cpu" - if self.bias is not None: - self.bias = self.bias.to(dtype=torch.float32) - - # intweight: k x n, zeros: k / group_size x n - intweight, zeros = unpack_to_8bit_signed(self.qweight, self.qzeros, self.bits) - if zeros is None: - zeros = torch.empty(0, dtype=torch.int8) - self.asym = False - else: - # change it to int8 with offset 128 - if self.bits == 8: - zeros = (zeros.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8) - else: - zeros -= 2 ** (self.bits - 1) - - if not self.asym: - intweight -= 2 ** (self.bits - 1) - intweight = intweight.to(torch.uint8 if self.asym else torch.int8) - # due to asym return torch.uint8 but backend request int8, - # change it to int8 with offset 128 - if self.asym: - intweight = (intweight.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8) - - scales = self.scales - - logger.debug( - f"QBits repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}" - ) - self.qweight = self.qbits.repack_quantized_weight( - intweight.contiguous(), - scales.float().contiguous(), - zeros.contiguous(), - torch.empty(0), - # weight_dtype - BITS_DTYPE_MAPPING[self.bits], - # scale_dtype - "fp32", - # TODO(zhe): consider dynamic-set cmpt for better perf? - "fp32", - self.asym, - self.group_size, - ) - # free mem - self.qzeros = torch.empty(0) - self.scales = torch.empty(0) - - def forward(self, x: torch.Tensor): - raw_input_dtype = x.dtype - if raw_input_dtype != torch.float32: - x = x.to(torch.float32) - out_shape = x.shape[:-1] + (self.outfeatures,) - x = x.view(-1, x.shape[-1]) # convert xd to 2d - out_2d_shape = x.shape[:-1] + (self.outfeatures,) - - outputs = torch.zeros(out_2d_shape, device=x.device, dtype=torch.float) - bias = self.bias if self.bias is not None else torch.empty(0, dtype=torch.float) - - self.qbits.woq_linear( - x, - self.qweight, - bias, - outputs, - convert_dtype_torch2str(torch.float), # compute_dtype - BITS_DTYPE_MAPPING[self.bits], # weight_dtype - "fp32", # scale_dtype - self.asym, - ) - return outputs.to(raw_input_dtype).view(out_shape) - - -@torch.no_grad() -def unpack_to_8bit_signed(qweight, qzeros, bits): - wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0) - zeros = None - if not torch.all(torch.eq(qzeros, 2004318071 if bits == 4 else 0b01111111011111110111111101111111)): - zp_shape = list(qzeros.shape) - zp_shape[1] = zp_shape[1] * (32 // bits) - - zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to( - torch.int16 if bits == 8 else torch.int8 - ) - torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) - if bits == 8: - zeros = zeros.to(torch.uint8) - try: - zeros = zeros.reshape(zp_shape) - except: - # zeros and scales have different iteam numbers. 
- # remove 1 (due to 0 + 1 in line 252) - zeros = zeros[zeros != 1] - zeros = zeros.reshape(zp_shape) - - weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to( - torch.int16 if bits == 8 else torch.int8 - ) - weight.bitwise_and_((2**bits) - 1) - weight = weight.view(-1, weight.shape[-1]) - - return weight, zeros - - -# Copied from qlinear_marlin.py -@torch.no_grad() -def dequantize_weight(qweight, qzeros, scales, bits): - unpacked_qweight, unpacked_qzeros = unpack_to_8bit_signed(qweight, qzeros, bits) - group_size = unpacked_qweight.shape[0] // scales.shape[0] - scales = scales.repeat_interleave(group_size, dim=0) - if unpacked_qzeros is not None: - unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0) - else: - unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32) - unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales - - return unpacked_qweight, unpacked_qzeros - - -__all__ = ["QuantLinear"] diff --git a/auto_round_extension/qbits/qlinear_qbits_gptq.py b/auto_round_extension/qbits/qlinear_qbits_gptq.py deleted file mode 100644 index c9c9e1fea..000000000 --- a/auto_round_extension/qbits/qlinear_qbits_gptq.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math - -import numpy as np -import torch -import torch.nn as nn - -from auto_round.utils import convert_dtype_torch2str, logger - -QBITS_AVAILABLE = True - -BITS_DTYPE_MAPPING = { - 2: "int2_clip", - 4: "int4_clip", - 8: "int8", -} - - -class QuantLinear(nn.Module): - QUANT_TYPE = "qbits_gptq" - - def __init__( - self, - bits, - group_size, - infeatures, - outfeatures, - bias, - kernel_switch_threshold=128, - trainable=False, - weight_dtype=torch.bfloat16, - **kwargs, - ): - super().__init__() - - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2, 4,8 bits are supported for QBits.") - - self.infeatures = infeatures - self.outfeatures = outfeatures - self.bits = bits - self.group_size = group_size if group_size != -1 else infeatures - self.maxq = 2**self.bits - 1 - self.weight_dtype = weight_dtype - self.asym = True - self.qbits = None - - self.register_buffer( - "qweight", - torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), - ) - self.register_buffer( - "qzeros", - torch.zeros( - ( - math.ceil(infeatures / self.group_size), - outfeatures // 32 * self.bits, - ), - dtype=torch.int32, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (math.ceil(infeatures / self.group_size), outfeatures), - dtype=weight_dtype, - ), - ) - if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float)) - else: - self.bias = None - - self.kernel_switch_threshold = kernel_switch_threshold - - self.trainable = trainable - - def req_check(self): - torch_version = str(torch.__version__) - if QBITS_AVAILABLE: - pass - # import intel_extension_for_transformers - # itrex_version = str(intel_extension_for_transformers.__version__) - # version_match_map = {"1.4": "2.2.0+cpu", - # "1.4.1": "2.2.0+cpu", "1.4.2": "2.3.0+cpu"} - # if itrex_version in version_match_map: - # if torch_version != version_match_map[itrex_version]: - # logger.warning( - # f"Please install torch {version_match_map[itrex_version]} by command 'pip install torch=={version_match_map[itrex_version]} --extra-index-url https://download.pytorch.org/whl/cpu' as Intel Extension for Transformers {itrex_version} is not compatible with current torch.") - else: - logger.error( - "Please install Intel Extension for Transformers by running 'pip install intel-extension-for-transformers' as qbits linear requirements checking fail. 
" - ) - exit(1) - - def post_init(self): - import intel_extension_for_transformers - - self.qbits = intel_extension_for_transformers.qbits - assert self.qweight.device.type == "cpu" - if self.bias is not None: - self.bias = self.bias.to(dtype=torch.float32) - - # intweight: k x n, zeros: k / group_size x n - intweight, zeros = unpack_to_8bit_signed(self.qweight, self.qzeros, self.bits) - if zeros is None: - zeros = torch.empty(0, dtype=torch.int8) - self.asym = False - else: - # change it to int8 with offset 128 - if self.bits == 8: - zeros = (zeros.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8) - else: - zeros -= 2 ** (self.bits - 1) - - if not self.asym: - intweight -= 2 ** (self.bits - 1) - intweight = intweight.to(torch.uint8 if self.asym else torch.int8) - # due to asym return torch.uint8 but backend request int8, - # change it to int8 with offset 128 - if self.asym: - intweight = (intweight.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8) - - scales = self.scales - - logger.debug( - f"QBits repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}" - ) - self.qweight = self.qbits.repack_quantized_weight( - intweight.contiguous(), - scales.float().contiguous(), - zeros.contiguous(), - torch.empty(0), - # weight_dtype - BITS_DTYPE_MAPPING[self.bits], - # scale_dtype - "fp32", - # TODO(zhe): consider dynamic-set cmpt for better perf? - "fp32", - self.asym, - self.group_size, - ) - # free mem - self.qzeros = torch.empty(0) - self.scales = torch.empty(0) - - def forward(self, x: torch.Tensor): - raw_input_dtype = x.dtype - if raw_input_dtype != torch.float32: - x = x.to(torch.float32) - out_shape = x.shape[:-1] + (self.outfeatures,) - x = x.view(-1, x.shape[-1]) # convert xd to 2d - out_2d_shape = x.shape[:-1] + (self.outfeatures,) - - outputs = torch.zeros(out_2d_shape, device=x.device, dtype=torch.float) - bias = self.bias if self.bias is not None else torch.empty(0, dtype=torch.float) - - self.qbits.woq_linear( - x, - self.qweight, - bias, - outputs, - convert_dtype_torch2str(torch.float), # compute_dtype - BITS_DTYPE_MAPPING[self.bits], # weight_dtype - "fp32", # scale_dtype - self.asym, - ) - return outputs.to(raw_input_dtype).view(out_shape) - - -@torch.no_grad() -def unpack_to_8bit_signed(qweight, qzeros, bits): - wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0) - zeros = None - if not torch.all(torch.eq(qzeros, 2004318071 if bits == 4 else 0b01111111011111110111111101111111)): - zp_shape = list(qzeros.shape) - zp_shape[1] = zp_shape[1] * (32 // bits) - - zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to( - torch.int16 if bits == 8 else torch.int8 - ) - torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) - if bits == 8: - zeros = zeros.to(torch.uint8) - zeros += 1 - try: - zeros = zeros.reshape(zp_shape) - except: - # zeros and scales have different iteam numbers. 
- # remove 1 (due to 0 + 1 in line 252) - zeros = zeros[zeros != 1] - zeros = zeros.reshape(zp_shape) - - weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to( - torch.int16 if bits == 8 else torch.int8 - ) - weight.bitwise_and_((2**bits) - 1) - weight = weight.view(-1, weight.shape[-1]) - - return weight, zeros - - -# Copied from qlinear_marlin.py -@torch.no_grad() -def dequantize_weight(qweight, qzeros, scales, bits): - unpacked_qweight, unpacked_qzeros = unpack_to_8bit_signed(qweight, qzeros, bits) - group_size = unpacked_qweight.shape[0] // scales.shape[0] - scales = scales.repeat_interleave(group_size, dim=0) - if unpacked_qzeros is not None: - unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0) - else: - unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32) - unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales - - return unpacked_qweight, unpacked_qzeros - - -__all__ = ["QuantLinear"] diff --git a/test/test_ark/test_model.py b/test/test_ark/test_model.py new file mode 100644 index 000000000..911a186c0 --- /dev/null +++ b/test/test_ark/test_model.py @@ -0,0 +1,127 @@ +import shutil +import sys +import unittest + +import pytest + +sys.path.insert(0, "../..") + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from auto_round import AutoRound, AutoRoundConfig +from auto_round.eval.evaluation import simple_evaluate_user_model +from auto_round.testing_utils import require_autogptq, require_gptqmodel + + +class LLMDataLoader: + def __init__(self): + self.batch_size = 1 + + def __iter__(self): + for i in range(2): + yield torch.ones([1, 10], dtype=torch.long) + + +class TestAutoRoundTorchBackend(unittest.TestCase): + + @classmethod + def setUpClass(self): + self.model_name = "facebook/opt-125m" + self.save_folder = "./saved" + self.llm_dataloader = LLMDataLoader() + + def model_infer(self, model, tokenizer): + prompts = [ + "Hello,my name is", + # "The president of the United States is", + # "The capital of France is", + # "The future of AI is", + ] + + inputs = tokenizer(prompts, return_tensors="pt", padding=False, truncation=True) + + outputs = model.generate( + input_ids=inputs["input_ids"].to(model.device), + attention_mask=inputs["attention_mask"].to(model.device), + do_sample=False, ## change this to follow official usage + max_new_tokens=5, + ) + generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(inputs["input_ids"], outputs)] + + decoded_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + for i, prompt in enumerate(prompts): + print(f"Prompt: {prompt}") + print(f"Generated: {decoded_outputs[i]}") + print("-" * 50) + return decoded_outputs[0] + + @classmethod + def tearDownClass(self): + shutil.rmtree(self.save_folder, ignore_errors=True) + shutil.rmtree("runs", ignore_errors=True) + + def test_torch_4bits_sym_cpu(self): + model = AutoModelForCausalLM.from_pretrained(self.model_name, dtype="auto", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + bits, group_size, sym = 4, 32, True + autoround = AutoRound( + model, + tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=0, + seqlen=2, + dataset=self.llm_dataloader, + ) + quantized_model_path = self.save_folder + autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round:gptqmodel") + + quantization_config = AutoRoundConfig(backend="ark") 
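+        # backend="ark" resolves through the alias list of the auto_round_kernel backends
+        # (BackendInfos alias=["ark"] in auto_round/inference/backend.py).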
+ model = AutoModelForCausalLM.from_pretrained( + quantized_model_path, dtype=torch.float16, device_map="cpu", quantization_config=quantization_config + ) + + tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + self.model_infer(model, tokenizer) + result = simple_evaluate_user_model(model, tokenizer, batch_size=32, tasks="lambada_openai", limit=1000) + print(result["results"]["lambada_openai"]["acc,none"]) + self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.28) + + shutil.rmtree("./saved", ignore_errors=True) + + def test_torch_4bits_sym_xpu(self): + model = AutoModelForCausalLM.from_pretrained(self.model_name, dtype="auto", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + bits, group_size, sym = 4, 32, True + autoround = AutoRound( + model, + tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=0, + seqlen=2, + dataset=self.llm_dataloader, + ) + quantized_model_path = self.save_folder + autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") ##will convert to gptq model + + quantization_config = AutoRoundConfig(backend="ark") + model = AutoModelForCausalLM.from_pretrained( + quantized_model_path, dtype=torch.float16, device_map="xpu", quantization_config=quantization_config + ) + + tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + self.model_infer(model, tokenizer) + result = simple_evaluate_user_model(model, tokenizer, batch_size=32, tasks="lambada_openai", limit=1000) + print(result["results"]["lambada_openai"]["acc,none"]) + self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.28) + torch.xpu.empty_cache() + shutil.rmtree(self.save_folder, ignore_errors=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_cpu/test_torch_backend.py b/test/test_cpu/test_torch_backend.py index 557bf1f38..d1e9bd293 100644 --- a/test/test_cpu/test_torch_backend.py +++ b/test/test_cpu/test_torch_backend.py @@ -27,7 +27,7 @@ class TestAutoRoundTorchBackend(unittest.TestCase): @classmethod def setUpClass(self): - self.model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + self.model_name = "facebook/opt-125m" self.save_folder = "./saved" self.llm_dataloader = LLMDataLoader() @@ -63,7 +63,7 @@ def tearDownClass(self): shutil.rmtree("runs", ignore_errors=True) def test_torch_4bits_asym(self): - model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(self.model_name, dtype="auto", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) bits, group_size, sym = 4, 128, False autoround = AutoRound( @@ -81,7 +81,7 @@ def test_torch_4bits_asym(self): quantization_config = AutoRoundConfig(backend="torch") model = AutoModelForCausalLM.from_pretrained( - quantized_model_path, torch_dtype=torch.float16, device_map="cpu", quantization_config=quantization_config + quantized_model_path, dtype=torch.float16, device_map="cpu", quantization_config=quantization_config ) tokenizer = AutoTokenizer.from_pretrained(self.save_folder) @@ -92,7 +92,7 @@ def test_torch_4bits_asym(self): torch.cuda.empty_cache() model = AutoModelForCausalLM.from_pretrained( - self.save_folder, torch_dtype=torch.bfloat16, device_map="cpu", quantization_config=quantization_config + self.save_folder, dtype=torch.bfloat16, device_map="cpu", quantization_config=quantization_config ) tokenizer = 
AutoTokenizer.from_pretrained(self.save_folder) @@ -104,9 +104,9 @@ def test_torch_4bits_asym(self): shutil.rmtree("./saved", ignore_errors=True) def test_torch_4bits_sym(self): - model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(self.model_name, dtype="auto", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) - bits, group_size, sym = 4, 128, True + bits, group_size, sym = 4, 32, True autoround = AutoRound( model, tokenizer, @@ -122,12 +122,12 @@ def test_torch_4bits_sym(self): quantization_config = AutoRoundConfig(backend="torch") model = AutoModelForCausalLM.from_pretrained( - quantized_model_path, torch_dtype=torch.float16, device_map="cpu", quantization_config=quantization_config + quantized_model_path, dtype=torch.float16, device_map="auto", quantization_config=quantization_config ) tokenizer = AutoTokenizer.from_pretrained(self.save_folder) self.model_infer(model, tokenizer) - result = simple_evaluate_user_model(model, tokenizer, batch_size=16, tasks="lambada_openai", limit=10) + result = simple_evaluate_user_model(model, tokenizer, batch_size=32, tasks="lambada_openai", limit=1000) print(result["results"]["lambada_openai"]["acc,none"]) self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.28) torch.cuda.empty_cache()