@@ -278,31 +278,29 @@ async def serve_inner(
278
278
279
279
if quantize == "gptq" and deployment_framework == "tgis_native" :
280
280
from text_generation_server .utils .layers import HAS_GPTQ_CUDA , EXLLAMA_VERSION
281
- if HAS_GPTQ_CUDA :
282
- if EXLLAMA_VERSION is not None :
283
- try :
284
- # When using GPTQ, Exllama kernels need some global kernels
285
- # For which we have the final shapes only after the model has loaded
286
- # This will allocate those buffers.
287
-
288
- if EXLLAMA_VERSION == "1" :
289
- from text_generation_server .utils .gptq .exllama import (
290
- create_exllama_buffers , set_device ,
291
- )
292
- set_device (device )
293
- create_exllama_buffers (max_sequence_length )
294
- else :
295
- assert EXLLAMA_VERSION == "2"
296
- from text_generation_server .utils .gptq .exllamav2 import (
297
- set_device , Ex4bitLinearV2 ,
298
- )
299
- set_device (device )
300
- for _ , submodule in model .model .named_modules ():
301
- if isinstance (submodule , Ex4bitLinearV2 ):
302
- submodule .post_init () # make q matrix and set scratch space
303
-
304
- except ImportError :
305
- print ("WARN: Error setting up GPTQ exllama buffers" )
281
+ if HAS_GPTQ_CUDA and EXLLAMA_VERSION is not None :
282
+ try :
283
+ # When using GPTQ, Exllama kernels need some global kernels
284
+ # For which we have the final shapes only after the model has loaded
285
+ # This will allocate those buffers.
286
+ if EXLLAMA_VERSION == "1" :
287
+ from text_generation_server .utils .gptq .exllama import (
288
+ create_exllama_buffers , set_device ,
289
+ )
290
+ set_device (device )
291
+ create_exllama_buffers (max_sequence_length )
292
+ elif EXLLAMA_VERSION == "2" :
293
+ from text_generation_server .utils .gptq .exllamav2 import (
294
+ set_device , Ex4bitLinearV2 ,
295
+ )
296
+ set_device (device )
297
+ for _ , submodule in model .model .named_modules ():
298
+ if isinstance (submodule , Ex4bitLinearV2 ):
299
+ submodule .post_init () # make q matrix and set scratch space
300
+ else :
301
+ raise ValueError (f"Unsupported { EXLLAMA_VERSION = } " )
302
+ except ImportError :
303
+ print ("WARN: Error setting up GPTQ exllama buffers" )
306
304
307
305
if local_rank == 0 and device .type == "cuda" :
308
306
# Log GPU memory stats at startup
0 commit comments