diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index d56178ad..6aa4f22d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -222,12 +222,22 @@ def __init__(
             weights=weights,
             bias=False,
         )
+
+        noshard_o_proj = False
+        if config.quantize == "gptq":
+            from text_generation_server.utils.layers import IS_TP_AWARE_GPTQ
+            noshard_o_proj = IS_TP_AWARE_GPTQ
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
+            noshard=noshard_o_proj,  # don't shard the o_proj weight matrix when the TP-aware masked-matmul path is enabled
         )
+        self.noshard_o_proj = noshard_o_proj
+        self.world_size = weights.process_group.size()
+        self.rank = weights.process_group.rank()

     def forward(
         self,
@@ -285,9 +295,19 @@ def forward(
                 1,
                 False,
             )
+        attn_output = attn_output.reshape(-1, self.num_heads * self.head_size)

-        return self.o_proj(attn_output.reshape(-1, self.num_heads * self.head_size))
+        # TP-aware masked-matmul optimization: zero-fill the activation to full width
+        # and multiply it with the full (unsharded) weight matrix in o_proj
+        if self.noshard_o_proj:
+            shard_size = attn_output.shape[1]
+            # assert shard_size * self.world_size == self.o_proj.linear.height
+            zf_attn_output = torch.zeros((attn_output.shape[0], shard_size * self.world_size), dtype=attn_output.dtype, device=attn_output.device)
+            start_idx = self.rank * shard_size
+            zf_attn_output[:, start_idx:start_idx + shard_size] = attn_output
+            attn_output = zf_attn_output
+        return self.o_proj(attn_output)


 class LlamaMLP(nn.Module):
     def __init__(self, prefix, config, weights):
@@ -303,6 +323,17 @@ def __init__(self, prefix, config, weights):
                 else "none",
             )
         )
+
+        # For the TP-aware preshuffle optimization, load the g_idx of down_proj to compute the permutation.
+        # When perm is None, the original unoptimized control path is taken.
+        perm = None
+        if config.quantize == "gptq":
+            from text_generation_server.utils.layers import IS_TP_AWARE_GPTQ
+            if IS_TP_AWARE_GPTQ:
+                down_proj_g_idx = weights.get_tensor(f"{prefix}.down_proj.g_idx")
+                if down_proj_g_idx is not None:
+                    perm = torch.argsort(down_proj_g_idx)
+
         # Fuse gate and up proj
         self.gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
@@ -310,12 +341,14 @@ def __init__(self, prefix, config, weights):
             weights=weights,
             dim=0,
             bias=False,
+            col_perm=perm,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
+            row_perm=perm,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
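Illustrative sketch (not applied by the patch) of why the masked-matmul path in FlashLlamaAttention.forward is sound: with noshard=True every rank keeps the full o_proj weight, and zero-filling the local activation slice before the matmul yields exactly the partial sum that the usual row-parallel shard would produce, so the subsequent all-reduce is unchanged. A dense random W stands in for the dequantized GPTQ weight; shapes, rank, and world size are arbitrary.

import torch

torch.manual_seed(0)
world_size, rank = 4, 1
tokens, hidden = 3, 32
shard = hidden // world_size

W = torch.randn(hidden, hidden)          # stand-in for the full, unsharded o_proj weight
x_shard = torch.randn(tokens, shard)     # this rank's slice of the attention output

# conventional row-parallel contribution: activation shard times the matching row block of W
partial = x_shard @ W[rank * shard:(rank + 1) * shard]

# masked-matmul contribution: zero-fill the activation to full width, multiply by the full W
x_full = torch.zeros(tokens, hidden)
x_full[:, rank * shard:(rank + 1) * shard] = x_shard
masked = x_full @ W

torch.testing.assert_close(partial, masked)  # identical partial sums feed the all-reduce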
diff --git a/server/text_generation_server/utils/gptq/shuffle.py b/server/text_generation_server/utils/gptq/shuffle.py
new file mode 100644
index 00000000..a21666e4
--- /dev/null
+++ b/server/text_generation_server/utils/gptq/shuffle.py
@@ -0,0 +1,65 @@
+import torch
+
+# Shuffle columns of scales
+def shuffle_and_replace_scales(state_dict, scales_name, col_perm):
+    scales = state_dict[scales_name]
+    assert len(col_perm) == scales.shape[1]
+
+    shuffled_scales = scales[:, col_perm]
+    state_dict[scales_name] = shuffled_scales
+
+def unpack_shuffle_repack_and_replace_qzeros(state_dict, bits, qzeros_name, col_perm):
+    qzeros = state_dict[qzeros_name]
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    assert len(col_perm) == qzeros.shape[1] * pack_size
+
+    # unpack
+    unpacked_qzeros = torch.zeros((qzeros.shape[0], qzeros.shape[1] * pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_qzeros[:, i::pack_size] = (qzeros >> (i * bits)) & mask
+
+    # shuffle
+    shuffled_qzeros = unpacked_qzeros[:, col_perm]
+
+    # repack
+    packed_qzeros = torch.zeros_like(qzeros)
+    for i in range(pack_size):
+        packed_qzeros |= (shuffled_qzeros[:, i::pack_size] & mask) << (i * bits)
+
+    state_dict[qzeros_name] = packed_qzeros
+
+def shuffle_and_replace_qweight(state_dict, bits, group_size, qweight_name, g_idx_name=None, next_g_idx_name=None, stable=False):
+    qweight = state_dict[qweight_name]
+
+    # unpack qweight
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_qweight = torch.zeros((qweight.shape[0] * pack_size, qweight.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_qweight[i::pack_size] = (qweight >> (i * bits)) & mask
+
+    # reorder rows conditionally (into act-order, following this layer's g_idx)
+    if g_idx_name is not None:
+        g_idx = state_dict[g_idx_name]
+        row_perm = torch.argsort(g_idx, stable=stable)
+        unpacked_qweight = unpacked_qweight[row_perm]
+
+    # reorder columns conditionally (following the consumer layer's g_idx)
+    if next_g_idx_name is not None:
+        next_g_idx = state_dict[next_g_idx_name]
+        col_perm = torch.argsort(next_g_idx, stable=stable)
+        unpacked_qweight = unpacked_qweight[:, col_perm]
+
+    # pack qweight
+    packed_qweight = torch.zeros_like(qweight)
+    for i in range(pack_size):
+        packed_qweight |= (unpacked_qweight[i::pack_size] & mask) << (i * bits)
+
+    # replace qweight with the new reordered one in state_dict
+    print(f"replacing {qweight_name}")
+    state_dict[qweight_name] = packed_qweight
+
+    if g_idx_name is not None:
+        print(f"replacing {g_idx_name}")
+        state_dict[g_idx_name] = torch.arange(0, len(g_idx), dtype=torch.int) // group_size
\ No newline at end of file
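Quick sanity check (not part of the patch) for the qzeros helper above: unpacking, shuffling columns with col_perm, and repacking is equivalent to shuffling the underlying 4-bit codes directly. The shapes and the "layer.qzeros" key are arbitrary; the import assumes the module path introduced by this patch.

import torch
from text_generation_server.utils.gptq.shuffle import unpack_shuffle_repack_and_replace_qzeros

bits, pack_size, mask = 4, 8, 0xF
rows, packed_cols = 2, 3
codes = torch.randint(0, 16, (rows, packed_cols * pack_size), dtype=torch.int)  # raw 4-bit zero-points

# pack 8 codes per int32 column, matching the GPTQ qzeros layout
qzeros = torch.zeros((rows, packed_cols), dtype=torch.int)
for i in range(pack_size):
    qzeros |= (codes[:, i::pack_size] & mask) << (i * bits)

col_perm = torch.randperm(packed_cols * pack_size)
state_dict = {"layer.qzeros": qzeros.clone()}
unpack_shuffle_repack_and_replace_qzeros(state_dict, bits, "layer.qzeros", col_perm)

# unpack the result and compare against shuffling the raw codes directly
unpacked = torch.zeros_like(codes)
for i in range(pack_size):
    unpacked[:, i::pack_size] = (state_dict["layer.qzeros"] >> (i * bits)) & mask
torch.testing.assert_close(unpacked, codes[:, col_perm])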
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 312f4d5d..5f985287 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -17,6 +17,8 @@
 HAS_GPTQ_CUDA = False
 GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "exllama").lower()
 GPTQ_CUDA_LINEAR = None
+# TODO: should disable TP-aware GPTQ automatically if deployment is single GPU
+IS_TP_AWARE_GPTQ = os.getenv("ENABLE_TP_AWARE_GPTQ", "False").lower() not in ["false", "0"]

 if torch.cuda.is_available():
     try:
@@ -279,13 +281,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

 class TensorParallelColumnLinear(SuperLayer):
     @classmethod
-    def load(cls, config, prefix: str, weights, bias: bool):
-        return cls.load_multi(config, [prefix], weights, bias, dim=0)
+    def load(cls, config, prefix: str, weights, bias: bool, col_perm=None):
+        return cls.load_multi(config, [prefix], weights, bias, dim=0, col_perm=col_perm)

     @classmethod
-    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
+    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int, col_perm=None):
         weight = weights.get_multi_weights_col(
-            prefixes, quantize=config.quantize, dim=dim
+            prefixes, quantize=config.quantize, dim=dim, col_perm=col_perm
         )

         if bias:
@@ -303,8 +305,8 @@ def __init__(self, linear, process_group):
         self.process_group = process_group

     @classmethod
-    def load(cls, config, prefix: str, weights, bias: bool):
-        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+    def load(cls, config, prefix: str, weights, bias: bool, row_perm=None, noshard=False):
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize, row_perm=row_perm, noshard=noshard)

         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
             bias = weights.get_tensor(f"{prefix}.bias")
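The col_perm/row_perm pair threaded through these loaders relies on a simple identity: permuting the output columns of the up/gate projection and the input rows of the down projection by the same permutation (argsort of down_proj's g_idx in the patch) leaves the MLP output unchanged. A sketch with dense stand-ins for the dequantized weights; all names and shapes below are illustrative only.

import torch

torch.manual_seed(0)
hidden, inter, tokens = 16, 40, 3
x = torch.randn(tokens, hidden, dtype=torch.float64)
up = torch.randn(hidden, inter, dtype=torch.float64)    # stand-in for the dequantized gate/up projection
down = torch.randn(inter, hidden, dtype=torch.float64)  # stand-in for the dequantized down projection

perm = torch.randperm(inter)  # in the patch: perm = torch.argsort(down_proj.g_idx)

reference = (x @ up) @ down
preshuffled = (x @ up[:, perm]) @ down[perm, :]   # col_perm on the producer, row_perm on the consumer
torch.testing.assert_close(reference, preshuffled)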
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 3a53eb36..b02de5b4 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -10,6 +10,44 @@

 QUANTIZE_CONFIG_FILENAME = "quantize_config.json"

+
+def unpack(x, dim, bits=4):
+    return unpack_row(x, bits) if dim == 0 else unpack_col(x, bits)
+
+def unpack_col(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_x = torch.zeros((x.shape[0], x.shape[1] * pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_x[:, i::pack_size] = (x >> (i * bits)) & mask
+    return unpacked_x
+
+def unpack_row(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_x = torch.zeros((x.shape[0] * pack_size, x.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_x[i::pack_size] = (x >> (i * bits)) & mask
+    return unpacked_x
+
+
+def pack(x, dim, bits=4):
+    return pack_row(x, bits) if dim == 0 else pack_col(x, bits)
+
+def pack_col(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    packed_x = torch.zeros((x.shape[0], x.shape[1] // pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        packed_x |= (x[:, i::pack_size] & mask) << (i * bits)
+    return packed_x
+
+def pack_row(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    packed_x = torch.zeros((x.shape[0] // pack_size, x.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        packed_x |= (x[i::pack_size] & mask) << (i * bits)
+    return packed_x

 class Weights:
     def __init__(
@@ -101,7 +139,7 @@ def get_partial_sharded(self, tensor_name: str, dim: int):
         tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_sharded(self, tensor_name: str, dim: int, perm=None, packed=False):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -110,17 +148,53 @@ def get_sharded(self, tensor_name: str, dim: int):
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim)
+        if perm is None:
+            return self.get_partial_sharded(tensor_name, dim)
+        else:
+            return self.get_shuffle_sharded(tensor_name, dim, perm, packed)
+
+    def get_shuffle_sharded(self, tensor_name: str, dim: int, perm, packed: bool):
+        filename, tensor_name = self.get_filename(tensor_name)
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+
+        f = self._get_handle(filename)
+        tensor = f.get_tensor(tensor_name)
+        perm = perm.to(device=tensor.device)
+        size = tensor.shape[dim]
+        block_size = size // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size

-    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
+        # TODO: pack-unpack on cuda to speed up this part
+        if dim == 0:
+            if packed:
+                tensor = pack(unpack(tensor, dim)[perm], dim)[start:stop]
+            else:
+                tensor = tensor[perm][start:stop]
+        elif dim == 1:
+            if packed:
+                tensor = pack(unpack(tensor, dim)[:, perm], dim)[:, start:stop]
+            else:
+                tensor = tensor[:, perm][:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype != torch.int32:
+            tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
+
+    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int, col_perm=None):
         if quantize == "gptq":
             try:
-                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
-            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1, perm=col_perm, packed=True) for p in prefixes], dim=1)
+            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
             w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
             for w2 in w[1:]:
                 torch.testing.assert_close(w2, w[0])
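Quick check (not part of the patch) of the new pack/unpack helpers and the packed branch of get_shuffle_sharded for dim=0: pack and unpack are exact inverses for 4-bit codes, and pack(unpack(x, dim)[perm], dim)[start:stop] equals packing this rank's block of the permuted raw codes. Shapes, rank, and world size below are arbitrary.

import torch
from text_generation_server.utils.weights import pack, unpack

pack_size = 8                                       # 32 // 4 bits
world_size, rank = 2, 1
codes = torch.randint(0, 16, (6 * pack_size, 5), dtype=torch.int)   # raw 4-bit values

qweight = pack(codes, dim=0)                        # GPTQ-style row packing into int32
torch.testing.assert_close(unpack(qweight, dim=0), codes)           # exact round trip

perm = torch.randperm(codes.shape[0])
block = qweight.shape[0] // world_size
start, stop = rank * block, (rank + 1) * block

shard = pack(unpack(qweight, dim=0)[perm], dim=0)[start:stop]       # what get_shuffle_sharded computes
ref = pack(codes[perm][start * pack_size:stop * pack_size], dim=0)  # permute the raw rows, then pack this rank's block
torch.testing.assert_close(shard, ref)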
@@ -141,39 +215,36 @@
             weight = torch.cat(w, dim=dim)
         return weight

-    def get_multi_weights_row(self, prefix: str, quantize: str):
+    def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, noshard=False):
         if quantize == "gptq":
             bits, groupsize = self._get_gptq_params()

-            use_gptq_cuda = bits == 4
-
-            if self.process_group.size() > 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-                if g_idx is not None:
-                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
-                        # Exllama implementation does not support row tensor parallelism with act-order, as
-                        # it would require to reorder input activations that are split unto several GPUs
-                        use_gptq_cuda = False
+            from text_generation_server.utils.layers import HAS_GPTQ_CUDA, IS_TP_AWARE_GPTQ
+            is_preshuffle = row_perm is not None
+            is_masked_matmul = noshard
+            assert not (is_preshuffle and is_masked_matmul), \
+                f"The TP-aware preshuffle and masked-matmul optimizations cannot both be enabled at once: {is_preshuffle=}, {is_masked_matmul=}"
+            use_gptq_cuda = bits == 4 and HAS_GPTQ_CUDA and IS_TP_AWARE_GPTQ and (is_preshuffle or is_masked_matmul)
+            if self.process_group.rank() == 0:
+                if use_gptq_cuda:
+                    logger.info(f"Using GPTQ cuda kernels for row {prefix}")
+                else:
+                    logger.warning(
+                        "GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
+                        " or not currently installed, try using BUILD_EXTENSIONS=True"
+                    )

             try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                qweight = self.get_sharded(f"{prefix}.qweight",
+                    dim=0,
+                    perm=row_perm if use_gptq_cuda else None,
+                    packed=True,
+                ) if not is_masked_matmul else self.get_tensor(f"{prefix}.qweight")
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")

-            from text_generation_server.utils.layers import HAS_GPTQ_CUDA
-            if use_gptq_cuda:
-                use_gptq_cuda = HAS_GPTQ_CUDA
-                if self.process_group.rank == 0:
-                    if use_gptq_cuda:
-                        logger.info(f"Using GPTQ cuda kernels for row {prefix}")
-                    else:
-                        logger.warning(
-                            "GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
-                            " or not currently installed, try using BUILD_EXTENSIONS=True"
-                        )
-
             if use_gptq_cuda:
-                if groupsize >= 0:
+                if groupsize >= 0 and not is_masked_matmul:
                     # Exllama reorders the weights in advance and the activations on the fly, thus
                     # the scales and zero-points do not need to be reordered.
                     qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
@@ -183,7 +254,7 @@
                     scales = self.get_tensor(f"{prefix}.scales")

                 # For tp > 1, at this point we know we do not use act-order
-                if self.process_group.size() == 1:
+                if (self.process_group.size() == 1 or is_masked_matmul) and not is_preshuffle:
                     g_idx = self.get_tensor(f"{prefix}.g_idx")
                 else:
                     g_idx = None
@@ -197,7 +268,7 @@

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda)
         else:
-            weight = self.get_sharded(f"{prefix}.weight", dim=1)
+            weight = self.get_sharded(f"{prefix}.weight", dim=1) if not noshard else self.get_tensor(f"{prefix}.weight")
         return weight

     def _get_gptq_params(self) -> Tuple[int, int]:
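For reference, a compact summary (not code from the patch) of how the two mutually exclusive switches reaching get_multi_weights_row select a loading strategy, assuming ENABLE_TP_AWARE_GPTQ is set, bits == 4, and the GPTQ CUDA kernels are available; the helper name and return strings are purely illustrative.

def row_loading_path(row_perm=None, noshard=False):
    is_preshuffle = row_perm is not None
    is_masked_matmul = noshard
    assert not (is_preshuffle and is_masked_matmul)
    if is_masked_matmul:
        # o_proj: keep the full qweight/qzeros/scales and g_idx on every rank;
        # the caller zero-fills its activation slice instead of sharding the weight
        return "masked matmul: load full tensors, keep g_idx"
    if is_preshuffle:
        # down_proj: unpack, permute into act-order, repack, and shard the qweight rows;
        # g_idx becomes None because the rows are already in act-order
        return "preshuffle: permuted row sharding, g_idx=None"
    # neither switch set: the original row-sharded path
    return "default: plain row sharding"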