
Commit 7b24de1

Simplify exllamav2 scratch space buffer allocation

1 parent d33efb4
2 files changed: +28 -56 lines

server/text_generation_server/server.py

Lines changed: 7 additions & 18 deletions

@@ -262,7 +262,7 @@ async def serve_inner(
     print(f"Using device {device}, dtype {dtype_str}, quantize {quantize}")
     print(model.config.__str__())
 
-    if quantize == "gptq":
+    if quantize == "gptq" and deployment_framework == "hf_custom_tp":
        from text_generation_server.utils.layers import HAS_EXLLAMA, EXLLAMA_VERSION
        if HAS_EXLLAMA:
            try:
@@ -272,27 +272,16 @@ async def serve_inner(
 
                 if EXLLAMA_VERSION == "1":
                     from text_generation_server.utils.gptq.exllama import (
-                        create_exllama_buffers,
-                        set_device,
+                        create_exllama_buffers, set_device,
                     )
+                    set_device(device)
+                    create_exllama_buffers(max_sequence_length)
                 else:
+                    assert EXLLAMA_VERSION == "2"
                     from text_generation_server.utils.gptq.exllamav2 import (
-                        create_exllama_buffers,
-                        set_device,
-                        Ex4bitLinearV2,
-                    )
-
-                set_device(device)
-
-                if EXLLAMA_VERSION == "1":
-                    create_exllama_buffers(max_sequence_length)
-                elif EXLLAMA_VERSION == "2":
-                    # NOTE: We're assuming that in this case, max_batch_weight == max_batch_tokens
-                    # This will likely need to change soon when we rework the batching parameters
-                    max_batch_tokens = max_batch_weight if max_batch_weight is not None else (
-                        max_batch_size * max_sequence_length
+                        set_device, Ex4bitLinearV2,
                     )
-                    create_exllama_buffers(max_batch_tokens)
+                    set_device(device)
 
                 for _, submodule in model.model.named_modules():
                     if isinstance(submodule, Ex4bitLinearV2):
                         submodule.post_init()  # make q matrix and set scratch space
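
With this change the v2 path no longer takes a batch-size-dependent buffer call: set_device both records the device and allocates the scratch tensor, so the NOTE about max_batch_weight vs. max_batch_tokens disappears along with create_exllama_buffers. A minimal sketch of the resulting init order (init_exllama_v2 is a hypothetical helper; model and device stand in for the server's real objects):

from text_generation_server.utils.gptq.exllamav2 import set_device, Ex4bitLinearV2

def init_exllama_v2(model, device):
    # By now the model is fully constructed, so every Ex4bitLinearV2.__init__
    # has already folded its width * height into the global MAX_INOUT_PRODUCT.
    # set_device records the device and allocates the shared scratch tensor.
    set_device(device)
    # Each quantized layer then builds its Q matrix against that scratch space.
    for _, submodule in model.model.named_modules():
        if isinstance(submodule, Ex4bitLinearV2):
            submodule.post_init()  # make q matrix and set scratch space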

server/text_generation_server/utils/gptq/exllamav2.py

Lines changed: 21 additions & 38 deletions

@@ -19,6 +19,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
     gemm_half_q_half(x, q_handle, output, force_cuda)
     return output.view(output_shape)
 
+
 def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
     """
     Create Q matrix
@@ -60,63 +61,46 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
         temp_dq,
     )
 
+
 def temp_dq_size(inout_product):
     return inout_product * 2 + 128
 
-def temp_fwd_size(outfeatures, max_batch_tokens):
-    return outfeatures * max_batch_tokens * 4 + 128
 
-def scratch_space_fixed(inout_product, outfeatures, max_batch_tokens):
-    return temp_dq_size(inout_product) + temp_fwd_size(outfeatures, max_batch_tokens)
+def _elements(size_bytes):
+    size_bytes = (size_bytes + 127) & -128  # round up to nearest multiple of 128
+    return size_bytes // 2
+
 
-class ExLlamaV2DeviceTensors:
+class ExLlamaV2DeviceTensor:
     def __init__(self, device, scratch_bytes):
         self.device = device
-        self.scratch_bytes = scratch_bytes
-        self.scratch = None
-
-    def prepare(self):
-        print_rank_n(f"Allocating {self.scratch_bytes // (1024 * 1024)} MiB for exllama v2 scratch space")
-        self.scratch = torch.empty((self.scratch_bytes // 2,), dtype=torch.half, device=self.device)
+        print_rank_n(f"Allocating {scratch_bytes // (1024 * 1024)} MiB for exllama v2 scratch space")
+        self.scratch = torch.empty(
+            _elements(scratch_bytes), dtype=torch.half, device=self.device
+        )
 
     def get_scratch_slice(self, size_bytes):
-        if self.scratch is None:
-            self.prepare()
+        size_half = _elements(size_bytes)
+        return self.scratch[:size_half]
 
-        size_bytes = ((size_bytes + 127) // 128) * 128
-        size_half = size_bytes // 2
-        scratch_slice = self.scratch.narrow(0, 0, size_half)
-        return scratch_slice
 
-# Max number of output features, used by temp_fwd_size calculation
-MAX_OUT_FEATURES = 1
-# Max of (infeatures * outfeatures), used by temp_dq_size calculation
-MAX_INOUT_PRODUCT = 1
 # DEVICE_TENSOR is a cuda buffer used by cublas gemm when M > 50
 DEVICE_TENSOR = None
 DEVICE = None
+# Max of (infeatures * outfeatures), used by temp_dq_size calculation
+MAX_INOUT_PRODUCT = 1
+
 
 def set_device(device):
-    global DEVICE
+    global DEVICE, DEVICE_TENSOR, MAX_INOUT_PRODUCT
     DEVICE = device
+    DEVICE_TENSOR = ExLlamaV2DeviceTensor(DEVICE, temp_dq_size(MAX_INOUT_PRODUCT))
 
-def create_exllama_buffers(max_batch_tokens: int):
-    global DEVICE, DEVICE_TENSOR, MAX_OUT_FEATURES, MAX_INOUT_PRODUCT
-
-    assert DEVICE is not None, "call set_device first"
-
-    DEVICE_TENSOR = ExLlamaV2DeviceTensors(
-        DEVICE,
-        scratch_space_fixed(
-            MAX_INOUT_PRODUCT,
-            MAX_OUT_FEATURES,
-            max_batch_tokens,
-        ))
 
 class Ex4bitLinearV2(nn.Module):
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
-        global MAX_OUT_FEATURES, MAX_INOUT_PRODUCT
+        global MAX_INOUT_PRODUCT
         super().__init__()
         assert bits == 4
 
@@ -134,9 +118,8 @@ def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         assert self.height % 32 == 0
         assert self.width % 32 == 0
 
-        # Update max outfeatures & inout_product so far for later call to create_exllama_buffers
-        MAX_OUT_FEATURES = max(MAX_OUT_FEATURES, self.width)
-        MAX_INOUT_PRODUCT = max(MAX_INOUT_PRODUCT, self.width*self.height)
+        # Update max inout_product so far for later call to set_device
+        MAX_INOUT_PRODUCT = max(MAX_INOUT_PRODUCT, self.width * self.height)
 
     def post_init(self):
         global DEVICE_TENSOR
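
The buffer is now sized purely from the largest weight matrix (temp_dq_size of the max infeatures * outfeatures); with temp_fwd_size gone, the allocation no longer scales with max_batch_tokens. A worked example of the arithmetic, using a hypothetical largest layer of 4096 x 4096 (the real value depends on the model's layers):

def temp_dq_size(inout_product):
    return inout_product * 2 + 128

def _elements(size_bytes):
    size_bytes = (size_bytes + 127) & -128  # round up to nearest multiple of 128
    return size_bytes // 2                  # fp16 elements, 2 bytes each

max_inout_product = 4096 * 4096                  # assumed largest infeatures * outfeatures
scratch_bytes = temp_dq_size(max_inout_product)  # 33_554_560 bytes, i.e. 32 MiB
print(scratch_bytes // (1024 * 1024))            # 32
print(_elements(scratch_bytes))                  # 16_777_280 half-precision elements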
