Performance Optimizations for TP-Aware GPTQ #67

Draft · wants to merge 5 commits into base: main
@@ -222,12 +222,22 @@ def __init__(
weights=weights,
bias=False,
)

noshard_o_proj = False
if config.quantize == 'gptq':
from text_generation_server.utils.layers import IS_TP_AWARE_GPTQ
noshard_o_proj = IS_TP_AWARE_GPTQ

self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
noshard=noshard_o_proj, # don't shard the o_proj weight matrix when the TP-aware optimization is enabled
)
self.noshard_o_proj = noshard_o_proj
self.world_size = weights.process_group.size()
self.rank = weights.process_group.rank()

def forward(
self,
@@ -285,9 +295,19 @@ def forward(
1,
False,
)
attn_output = attn_output.reshape(-1, self.num_heads * self.head_size)

return self.o_proj(attn_output.reshape(-1, self.num_heads * self.head_size))
# TP-aware masked-matmul optimization: zero-fill the activation and
# multiply it by the full (unsharded) weight matrix in o_proj
if self.noshard_o_proj:
shard_size = attn_output.shape[1]
# assert shard_size*self.world_size == self.o_proj.linear.height
zf_attn_output = torch.zeros((attn_output.shape[0], shard_size*self.world_size), dtype=attn_output.dtype, device=attn_output.device)
start_idx = self.rank * shard_size
zf_attn_output[:, start_idx:start_idx+shard_size] = attn_output
attn_output = zf_attn_output

return self.o_proj(attn_output)
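
The zero-fill trick above can be sanity-checked in isolation. A minimal sketch (not part of this PR; shapes and names are made up): each rank's zero-filled activation only touches the rows of the full o_proj weight that correspond to that rank's heads, so the partial products are the same ones a sharded weight would have produced, and the all-reduce inside TensorParallelRowLinear still sums them to the correct result.

import torch

torch.manual_seed(0)
world_size, tokens, shard, hidden = 2, 3, 4, 5
W = torch.randn(world_size * shard, hidden)   # full (unsharded) o_proj weight, x @ W layout
x = torch.randn(tokens, world_size * shard)   # attention output across all heads

partials = []
for rank in range(world_size):
    x_shard = x[:, rank * shard:(rank + 1) * shard]   # what this rank actually holds
    zf = torch.zeros_like(x)                          # zero-filled activation
    zf[:, rank * shard:(rank + 1) * shard] = x_shard
    partials.append(zf @ W)                           # masked matmul against the full weight

# the all-reduce (sum over ranks) reproduces the unsharded result
assert torch.allclose(sum(partials), x @ W, atol=1e-5)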

class LlamaMLP(nn.Module):
def __init__(self, prefix, config, weights):
@@ -303,19 +323,32 @@ def __init__(self, prefix, config, weights):
else "none",
)
)

# For the TP-aware preshuffle optimization, load the g_idx of down_proj to compute perm.
# When perm is None, the original unoptimized code path is taken.
perm = None
if config.quantize == "gptq":
from text_generation_server.utils.layers import IS_TP_AWARE_GPTQ
if IS_TP_AWARE_GPTQ:
down_proj_g_idx = weights.get_tensor(f"{prefix}.down_proj.g_idx")
if down_proj_g_idx is not None:
perm = torch.argsort(down_proj_g_idx)

# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
col_perm=perm,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
row_perm=perm,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
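The preshuffle above relies on a simple identity: permuting the MLP activation columns (done at load time via col_perm on gate_proj/up_proj) and the down_proj weight rows (row_perm) with the same argsort(g_idx) permutation leaves the matrix product unchanged, so the act-order reordering that exllama would otherwise perform on every forward pass is baked into the weights once. A minimal sketch, not part of this PR, with a made-up g_idx:

import torch

torch.manual_seed(0)
tokens, intermediate, hidden = 2, 8, 4
x = torch.randn(tokens, intermediate)             # MLP activation feeding down_proj
W = torch.randn(intermediate, hidden)             # down_proj weight in x @ W layout
g_idx = torch.tensor([3, 0, 1, 0, 2, 1, 3, 2])    # hypothetical act-order group indices
perm = torch.argsort(g_idx)

# same permutation applied to activation columns and weight rows: product is unchanged
assert torch.allclose(x[:, perm] @ W[perm, :], x @ W, atol=1e-6)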
65 changes: 65 additions & 0 deletions server/text_generation_server/utils/gptq/shuffle.py
@@ -0,0 +1,65 @@
import torch

# Shuffle columns of scales
def shuffle_and_replace_scales(state_dict, scales_name, col_perm):
scales = state_dict[scales_name]
assert len(col_perm) == scales.shape[1]

shuffled_scales = scales[:,col_perm]
state_dict[scales_name] = shuffled_scales

def unpack_shuffle_repack_and_replace_qzeros(state_dict, bits, qzeros_name, col_perm):
qzeros = state_dict[qzeros_name]
mask = 2**bits - 1
pack_size = 32 // bits
assert len(col_perm) == qzeros.shape[1] * pack_size

# unpack
unpacked_qzeros = torch.zeros((qzeros.shape[0], qzeros.shape[1]*pack_size), dtype=torch.int)
for i in range(pack_size):
unpacked_qzeros[:, i::pack_size] = (qzeros >> (i*bits)) & (mask)

# shuffle
shuffled_qzeros = unpacked_qzeros[:,col_perm]

# repack
packed_qzeros = torch.zeros_like(qzeros)
for i in range(pack_size):
packed_qzeros |= (shuffled_qzeros[:, i::pack_size] & mask) << (i*bits)

state_dict[qzeros_name] = packed_qzeros

def shuffle_and_replace_qweight(state_dict, bits, group_size, qweight_name, g_idx_name=None, next_g_idx_name=None, stable=False):
qweight = state_dict[qweight_name]

# unpack qweight
mask = 2**bits - 1
pack_size = 32 // bits
unpacked_qweight = torch.zeros((qweight.shape[0]*pack_size, qweight.shape[1]), dtype=torch.int)
for i in range(pack_size):
unpacked_qweight[i::pack_size] = (qweight >> (i*bits)) & (mask)

# reorder rows conditionally
if g_idx_name is not None:
g_idx = state_dict[g_idx_name]
row_perm = torch.argsort(g_idx, stable=stable)
unpacked_qweight = unpacked_qweight[row_perm]

# reorder columns conditionally
if next_g_idx_name is not None:
next_g_idx = state_dict[next_g_idx_name]
col_perm = torch.argsort(next_g_idx, stable=stable)
unpacked_qweight = unpacked_qweight[:,col_perm]

# pack qweight
packed_qweight = torch.zeros_like(qweight)
for i in range(pack_size):
packed_qweight |= (unpacked_qweight[i::pack_size] & mask) << (i*bits)

# replace qweight with new reordered one in state_dict
print(f'replacing {qweight_name}')
state_dict[qweight_name] = packed_qweight

if g_idx_name is not None:
print(f'replacing {g_idx_name}')
state_dict[g_idx_name] = torch.arange(0, len(g_idx), dtype=torch.int) // group_size
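
A hypothetical offline usage sketch of these helpers (not part of this PR; the checkpoint path, layer prefix, and 4-bit / group-size settings are made up for illustration). The idea is to shuffle the producer layer's output columns with argsort of the consumer's g_idx, and correspondingly reorder the consumer's rows, before the model is ever loaded:

import torch
from text_generation_server.utils.gptq.shuffle import (
    shuffle_and_replace_scales,
    unpack_shuffle_repack_and_replace_qzeros,
    shuffle_and_replace_qweight,
)

bits, group_size = 4, 128
prefix = "model.layers.0.mlp"                  # hypothetical layer prefix
state_dict = torch.load("gptq_checkpoint.pt")  # hypothetical GPTQ state_dict

# permutation that sorts down_proj's act-order groups
col_perm = torch.argsort(state_dict[f"{prefix}.down_proj.g_idx"])

# shuffle the producer (up_proj) output columns so its activations already
# arrive in the order down_proj expects
shuffle_and_replace_scales(state_dict, f"{prefix}.up_proj.scales", col_perm)
unpack_shuffle_repack_and_replace_qzeros(state_dict, bits, f"{prefix}.up_proj.qzeros", col_perm)
shuffle_and_replace_qweight(state_dict, bits, group_size, f"{prefix}.up_proj.qweight",
                            next_g_idx_name=f"{prefix}.down_proj.g_idx")

# reorder down_proj's own rows and reset its g_idx to sequential groups
shuffle_and_replace_qweight(state_dict, bits, group_size, f"{prefix}.down_proj.qweight",
                            g_idx_name=f"{prefix}.down_proj.g_idx")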
14 changes: 8 additions & 6 deletions server/text_generation_server/utils/layers.py
@@ -17,6 +17,8 @@
HAS_GPTQ_CUDA = False
GPTQ_CUDA_TYPE = os.getenv("GPTQ_CUDA_TYPE", "exllama").lower()
GPTQ_CUDA_LINEAR = None
# TODO: should disable TP-aware GPTQ automatically if deployment is single GPU
IS_TP_AWARE_GPTQ = (os.getenv("ENABLE_TP_AWARE_GPTQ","False").lower() not in ["false", "0"])

if torch.cuda.is_available():
try:
@@ -279,13 +281,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
return cls.load_multi(config, [prefix], weights, bias, dim=0)
def load(cls, config, prefix: str, weights, bias: bool, col_perm=None):
return cls.load_multi(config, [prefix], weights, bias, dim=0, col_perm=col_perm)

@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int, col_perm=None):
weight = weights.get_multi_weights_col(
prefixes, quantize=config.quantize, dim=dim
prefixes, quantize=config.quantize, dim=dim, col_perm=col_perm
)

if bias:
@@ -303,8 +305,8 @@ def __init__(self, linear, process_group):
self.process_group = process_group

@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
def load(cls, config, prefix: str, weights, bias: bool, row_perm=None, noshard=False):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize, row_perm=row_perm, noshard=noshard)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
135 changes: 103 additions & 32 deletions server/text_generation_server/utils/weights.py
@@ -10,6 +10,44 @@

QUANTIZE_CONFIG_FILENAME = "quantize_config.json"

def unpack(x, dim, bits=4):
return unpack_row(x, bits) if dim == 0 else unpack_col(x, bits)

def unpack_col(x, bits):
mask = 2**bits - 1
pack_size = 32 // bits
unpacked_x = torch.zeros((x.shape[0], x.shape[1]*pack_size), dtype=torch.int)
for i in range(pack_size):
unpacked_x[:, i::pack_size] = (x >> (i*bits)) & (mask)
return unpacked_x

def unpack_row(x, bits):
mask = 2**bits - 1
pack_size = 32 // bits
unpacked_x = torch.zeros((x.shape[0]*pack_size, x.shape[1]), dtype=torch.int)
for i in range(pack_size):
unpacked_x[i::pack_size] = (x >> (i*bits)) & (mask)
return unpacked_x


def pack(x, dim, bits=4):
return pack_row(x, bits) if dim == 0 else pack_col(x, bits)

def pack_col(x, bits):
mask = 2**bits - 1
pack_size = 32 // bits
packed_x = torch.zeros((x.shape[0], x.shape[1]//pack_size), dtype=torch.int)
for i in range(pack_size):
packed_x |= (x[:, i::pack_size] & mask) << (i*bits)
return packed_x

def pack_row(x, bits):
mask = 2**bits - 1
pack_size = 32 // bits
packed_x = torch.zeros((x.shape[0]//pack_size, x.shape[1]), dtype=torch.int)
for i in range(pack_size):
packed_x |= (x[i::pack_size] & mask) << (i*bits)
return packed_x
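
A quick illustrative round-trip check for these helpers (not part of this PR; it assumes the PR is applied so pack/unpack are importable): 4-bit values packed into int32 along either dimension are recovered exactly by the matching unpack.

import torch
from text_generation_server.utils.weights import pack, unpack

bits = 4
pack_size = 32 // bits
# random 4-bit values; the packed dimension must be a multiple of pack_size
vals_rows = torch.randint(0, 2**bits, (2 * pack_size, 6), dtype=torch.int)
vals_cols = torch.randint(0, 2**bits, (6, 2 * pack_size), dtype=torch.int)

assert torch.equal(unpack(pack(vals_rows, dim=0), dim=0), vals_rows)   # row-packed (qweight-style)
assert torch.equal(unpack(pack(vals_cols, dim=1), dim=1), vals_cols)   # column-packed (qzeros-style)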

class Weights:
def __init__(
@@ -101,7 +139,7 @@ def get_partial_sharded(self, tensor_name: str, dim: int):
tensor = tensor.to(device=self.device)
return tensor

def get_sharded(self, tensor_name: str, dim: int):
def get_sharded(self, tensor_name: str, dim: int, perm=None, packed=False):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
@@ -110,17 +148,53 @@ def get_sharded(self, tensor_name: str, dim: int):
assert (
size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
if perm is None:
return self.get_partial_sharded(tensor_name, dim)
else:
return self.get_shuffle_sharded(tensor_name, dim, perm, packed)

def get_shuffle_sharded(self, tensor_name: str, dim: int, perm, packed: bool):
filename, tensor_name = self.get_filename(tensor_name)
world_size = self.process_group.size()
rank = self.process_group.rank()

f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
perm = perm.to(device=tensor.device)
size = tensor.shape[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size

def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
# TODO: pack-unpack on cuda to speed up this part
if dim == 0:
if packed:
tensor = pack(unpack(tensor, dim)[perm], dim)[start:stop]
else:
tensor = tensor[perm][start:stop]
elif dim == 1:
if packed:
tensor = pack(unpack(tensor, dim)[:, perm], dim)[:, start:stop]
else:
tensor = tensor[:, perm][:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32
if tensor.dtype != torch.int32:
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
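
What get_shuffle_sharded buys us, in short: the global permutation is applied to the full tensor before each rank slices out its contiguous block, so every rank ends up holding exactly its block of the permuted order (for packed GPTQ tensors the same thing happens on the unpacked values). A tiny sketch, not part of this PR:

import torch

world_size, size = 2, 8
tensor = torch.arange(size)        # stand-in for the sharded dimension of a weight
perm = torch.randperm(size)        # e.g. argsort(g_idx) coming from the preshuffle path
block = size // world_size

shards = [tensor[perm][rank * block:(rank + 1) * block] for rank in range(world_size)]
# concatenating all ranks' shards reproduces the fully permuted tensor
assert torch.equal(torch.cat(shards), tensor[perm])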

def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int, col_perm=None):
if quantize == "gptq":
try:
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")

qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1, perm=col_perm, packed=True) for p in prefixes], dim=1)
scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
@@ -141,39 +215,36 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
weight = torch.cat(w, dim=dim)
return weight

def get_multi_weights_row(self, prefix: str, quantize: str):
def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, noshard=False):
if quantize == "gptq":
bits, groupsize = self._get_gptq_params()

use_gptq_cuda = bits == 4

if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
if g_idx is not None:
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_gptq_cuda = False
from text_generation_server.utils.layers import HAS_GPTQ_CUDA, IS_TP_AWARE_GPTQ
is_preshuffle = row_perm is not None
is_masked_matmul = noshard
assert not (is_preshuffle and is_masked_matmul), \
f"The TP-aware preshuffle and masked-matmul optimizations can't both be enabled at the same time: {is_preshuffle=}, {is_masked_matmul=}"

use_gptq_cuda = (bits == 4) and HAS_GPTQ_CUDA and (IS_TP_AWARE_GPTQ and (is_preshuffle or is_masked_matmul))
if self.process_group.rank() == 0:
if use_gptq_cuda:
logger.info(f"Using GPTQ cuda kernels for row {prefix}")
else:
logger.warning(
"GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
" or not currently installed, try using BUILD_EXTENSIONS=True"
)
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
qweight = self.get_sharded(f"{prefix}.qweight",
dim=0,
perm=row_perm if use_gptq_cuda else None,
packed=True,
) if not is_masked_matmul else self.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")

from text_generation_server.utils.layers import HAS_GPTQ_CUDA
if use_gptq_cuda:
use_gptq_cuda = HAS_GPTQ_CUDA
if self.process_group.rank == 0:
if use_gptq_cuda:
logger.info(f"Using GPTQ cuda kernels for row {prefix}")
else:
logger.warning(
"GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
" or not currently installed, try using BUILD_EXTENSIONS=True"
)

if use_gptq_cuda:
if groupsize >= 0:
if groupsize >= 0 and not is_masked_matmul:
# Exllama reorders the weights in advance and the activations on the fly, thus
# the scales and zero-points do not need to be reordered.
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
@@ -183,7 +254,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
scales = self.get_tensor(f"{prefix}.scales")

# For tp > 1, at this point we know we do not use act-order
if self.process_group.size() == 1:
if (self.process_group.size() == 1 or is_masked_matmul) and not is_preshuffle:
g_idx = self.get_tensor(f"{prefix}.g_idx")
else:
g_idx = None
@@ -197,7 +268,7 @@

weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_gptq_cuda)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
weight = self.get_sharded(f"{prefix}.weight", dim=1) if not noshard else self.get_tensor(f"{prefix}.weight")
return weight

def _get_gptq_params(self) -> Tuple[int, int]: