Skip to content

Commit 025b944

Browse files
committed
Apply suggestion on GPTQ buffer setup
Signed-off-by: cyang49 <[email protected]>
Signed-off-by: Chih-Chieh-Yang <[email protected]>
1 parent 96f3cd3 commit 025b944

File tree

1 file changed

+23
-25
lines changed

1 file changed

+23
-25
lines changed

server/text_generation_server/server.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -278,31 +278,29 @@ async def serve_inner(
278278

279279
if quantize == "gptq" and deployment_framework == "tgis_native":
280280
from text_generation_server.utils.layers import HAS_GPTQ_CUDA, EXLLAMA_VERSION
281-
if HAS_GPTQ_CUDA:
282-
if EXLLAMA_VERSION is not None:
283-
try:
284-
# When using GPTQ, Exllama kernels need some global kernels
285-
# For which we have the final shapes only after the model has loaded
286-
# This will allocate those buffers.
287-
288-
if EXLLAMA_VERSION == "1":
289-
from text_generation_server.utils.gptq.exllama import (
290-
create_exllama_buffers, set_device,
291-
)
292-
set_device(device)
293-
create_exllama_buffers(max_sequence_length)
294-
else:
295-
assert EXLLAMA_VERSION == "2"
296-
from text_generation_server.utils.gptq.exllamav2 import (
297-
set_device, Ex4bitLinearV2,
298-
)
299-
set_device(device)
300-
for _, submodule in model.model.named_modules():
301-
if isinstance(submodule, Ex4bitLinearV2):
302-
submodule.post_init() # make q matrix and set scratch space
303-
304-
except ImportError:
305-
print("WARN: Error setting up GPTQ exllama buffers")
281+
if HAS_GPTQ_CUDA and EXLLAMA_VERSION is not None:
282+
try:
283+
# When using GPTQ, Exllama kernels need some global kernels
284+
# For which we have the final shapes only after the model has loaded
285+
# This will allocate those buffers.
286+
if EXLLAMA_VERSION == "1":
287+
from text_generation_server.utils.gptq.exllama import (
288+
create_exllama_buffers, set_device,
289+
)
290+
set_device(device)
291+
create_exllama_buffers(max_sequence_length)
292+
elif EXLLAMA_VERSION == "2":
293+
from text_generation_server.utils.gptq.exllamav2 import (
294+
set_device, Ex4bitLinearV2,
295+
)
296+
set_device(device)
297+
for _, submodule in model.model.named_modules():
298+
if isinstance(submodule, Ex4bitLinearV2):
299+
submodule.post_init() # make q matrix and set scratch space
300+
else:
301+
raise ValueError(f"Unsupported {EXLLAMA_VERSION=}")
302+
except ImportError:
303+
print("WARN: Error setting up GPTQ exllama buffers")
306304

307305
if local_rank == 0 and device.type == "cuda":
308306
# Log GPU memory stats at startup

0 commit comments

Comments (0)