You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
fix: move parameter validation before fit_memory_scaling_model (IBM#101)
The launch of `fit_memory_scaling_model` uses the values for `quantize`
and `dtype_str`, so those should be validated and defaulted before it is
ran.
Before this change, if `dtype_str` was set to `None` it would be passed
to `fit_memory_scaling_model` as `None` resulting in an error:
```
Shard 1: Process SpawnProcess-33:
Shard 1: Traceback (most recent call last):
Shard 1: File "/opt/tgis/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
Shard 1: self.run()
Shard 1: File "/opt/tgis/lib/python3.11/multiprocessing/process.py", line 108, in run
Shard 1: self._target(*self._args, **self._kwargs)
Shard 1: File "/opt/tgis/lib/python3.11/site-packages/text_generation_server/utils/paged.py", line 37, in fit_memory_scaling_model
Shard 1: model = get_model(
Shard 1: ^^^^^^^^^^
Shard 1: File "/opt/tgis/lib/python3.11/site-packages/text_generation_server/models/__init__.py", line 39, in get_model
Shard 1: dtype = get_torch_dtype(dtype_str)
Shard 1: ^^^^^^^^^^^^^^^^^^^^^^^^^^
Shard 1: File "/opt/tgis/lib/python3.11/site-packages/text_generation_server/utils/dist.py", line 64, in get_torch_dtype
Shard 1: dt = getattr(torch, dtype_str, None)
Shard 1: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Shard 1: TypeError: attribute name must be string, not 'NoneType'
```
After this change, a value will always be set before calling
`fit_memory_scaling_model`.
Signed-off-by: Travis Johnson <[email protected]>
0 commit comments