
Commit 041fffb

fix: move parameter validation before fit_memory_scaling_model (IBM#101)
The launch of `fit_memory_scaling_model` uses the values of `quantize` and `dtype_str`, so those should be validated and defaulted before it is run. Before this change, if `dtype_str` was `None` it would be passed to `fit_memory_scaling_model` as `None`, resulting in an error:

```
Shard 1: Process SpawnProcess-33:
Shard 1: Traceback (most recent call last):
Shard 1:   File "/opt/tgis/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
Shard 1:     self.run()
Shard 1:   File "/opt/tgis/lib/python3.11/multiprocessing/process.py", line 108, in run
Shard 1:     self._target(*self._args, **self._kwargs)
Shard 1:   File "/opt/tgis/lib/python3.11/site-packages/text_generation_server/utils/paged.py", line 37, in fit_memory_scaling_model
Shard 1:     model = get_model(
Shard 1:             ^^^^^^^^^^
Shard 1:   File "/opt/tgis/lib/python3.11/site-packages/text_generation_server/models/__init__.py", line 39, in get_model
Shard 1:     dtype = get_torch_dtype(dtype_str)
Shard 1:             ^^^^^^^^^^^^^^^^^^^^^^^^^^
Shard 1:   File "/opt/tgis/lib/python3.11/site-packages/text_generation_server/utils/dist.py", line 64, in get_torch_dtype
Shard 1:     dt = getattr(torch, dtype_str, None)
Shard 1:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Shard 1: TypeError: attribute name must be string, not 'NoneType'
```

After this change, a value will always be set before calling `fit_memory_scaling_model`.

Signed-off-by: Travis Johnson <[email protected]>
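For context, a minimal sketch of the failure path and of the fix. Only the `getattr(torch, dtype_str, None)` line is taken from the traceback above; the stand-in helper, its error handling, and the example values are illustrative, not the actual TGIS code:

```python
import torch

def get_torch_dtype(dtype_str):
    # Simplified stand-in for text_generation_server.utils.dist.get_torch_dtype
    # (only the getattr line comes from the traceback). getattr() requires a
    # string attribute name, so dtype_str=None raises
    # "TypeError: attribute name must be string, not 'NoneType'".
    return getattr(torch, dtype_str, None)

dtype_str = None  # previously passed through to fit_memory_scaling_model as-is

# After the fix, the default is applied before the memory-scaling launch, so
# the helper never sees None:
if dtype_str is None:
    dtype_str = "float16" if torch.cuda.is_available() else "float32"

print(get_torch_dtype(dtype_str))  # torch.float16 or torch.float32
```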
1 parent 9b4aea8 commit 041fffb

File tree

1 file changed: +16 −16 lines changed

server/text_generation_server/server.py

Lines changed: 16 additions & 16 deletions
```diff
@@ -273,6 +273,22 @@ async def serve_inner(
     batch_safety_margin: int,
     sharded: bool = False,
 ):
+    if quantize not in [None, "gptq", "bitsandbytes"]:
+        raise ValueError(f"Unrecognized quantization method specified: {quantize}")
+
+    if quantize is None and dtype_str == "int8":
+        print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
+        quantize = "bitsandbytes"
+
+    cuda_available = torch.cuda.is_available()
+
+    # Default dtype based on device if not provided
+    if dtype_str is None:
+        dtype_str = "float16" if cuda_available else "float32"
+
+    if quantize is not None and not cuda_available:
+        raise ValueError("Quantization requires CUDA")
+
     if ESTIMATE_MEMORY == "auto" and PAGED_ATTENTION:
         # fit memory model using flash model in separate process (ensures GPU memory is entirely cleaned up)
         from text_generation_server.utils.paged import fit_memory_scaling_model
@@ -296,22 +312,6 @@ async def serve_inner(
     ]
     local_url = server_urls[local_rank]
 
-    if quantize not in [None, "gptq", "bitsandbytes"]:
-        raise ValueError(f"Unrecognized quantization method specified: {quantize}")
-
-    # Default dtype based on device if not provided
-    if dtype_str is None:
-        dtype_str = "float16" if torch.cuda.is_available() else "float32"
-
-    if quantize is None and dtype_str == "int8":
-        print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
-        quantize = "bitsandbytes"
-
-    cuda_available = torch.cuda.is_available()
-
-    if quantize is not None and not cuda_available:
-        raise ValueError("Quantization requires CUDA")
-
     # Set the fraction of cuda/gpu mem available to this process, then load the model
     if cuda_available and cuda_process_memory_fraction < 1:
         torch.cuda.set_per_process_memory_fraction(cuda_process_memory_fraction)
```
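To make the resulting ordering easier to see at a glance, the hunks above amount to the following condensed sketch. The helper name `validate_and_default` is hypothetical (the actual change inlines these statements at the top of `serve_inner`), and the real function takes many more parameters and launches `fit_memory_scaling_model` in a separate process:

```python
def validate_and_default(quantize, dtype_str, cuda_available):
    # Hypothetical helper mirroring the checks the diff moves ahead of the
    # memory-scaling launch; the statements match the added diff lines.
    if quantize not in [None, "gptq", "bitsandbytes"]:
        raise ValueError(f"Unrecognized quantization method specified: {quantize}")
    if quantize is None and dtype_str == "int8":
        quantize = "bitsandbytes"
    # Default dtype based on device if not provided
    if dtype_str is None:
        dtype_str = "float16" if cuda_available else "float32"
    if quantize is not None and not cuda_available:
        raise ValueError("Quantization requires CUDA")
    return quantize, dtype_str

# Validation and defaulting now happen before fit_memory_scaling_model is
# launched, so the child process always receives concrete values:
quantize, dtype_str = validate_and_default(None, None, cuda_available=False)
assert quantize is None and dtype_str == "float32"
```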
