Skip to content

Commit 07ca0b2

Browse files
Added support for launching inference endpoint with different model dtypes (#124)
* Added support for any dtype --------- Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent eb0d898 commit 07ca0b2

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

src/lighteval/models/endpoint_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
"MAX_INPUT_LENGTH": "2047",
8888
"MAX_TOTAL_TOKENS": "2048",
8989
"MODEL_ID": "/repository",
90+
**config.get_dtype_args(),
9091
},
9192
"url": "ghcr.io/huggingface/text-generation-inference:1.1.0",
9293
},

src/lighteval/models/model_config.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from argparse import Namespace
2424
from dataclasses import dataclass
25-
from typing import Optional, Union
25+
from typing import Dict, Optional, Union
2626

2727
import torch
2828
from transformers import AutoConfig, BitsAndBytesConfig, GPTQConfig, PretrainedConfig
@@ -233,11 +233,24 @@ class InferenceEndpointModelConfig:
233233
region: str
234234
instance_size: str
235235
instance_type: str
236+
model_dtype: str
236237
framework: str = "pytorch"
237238
endpoint_type: str = "protected"
238239
should_reuse_existing: bool = False
239240
add_special_tokens: bool = True
240241

242+
def get_dtype_args(self) -> Dict[str, str]:
    """Translate the configured model dtype into environment variables
    for the text-generation-inference container.

    Returns:
        Dict[str, str]: ``{"QUANTIZE": ...}`` for quantized dtypes,
        ``{"DTYPE": ...}`` for half-precision float dtypes, or an empty
        dict when no dtype override should be passed to the endpoint.
    """
    # Robustness: the loader treats model_dtype as optional
    # (`config.model_dtype or "default"`), so tolerate None/empty here
    # instead of raising AttributeError on `.lower()`.
    if not self.model_dtype:
        return {}
    model_dtype = self.model_dtype.lower()
    # Quantization backends handled natively by TGI via QUANTIZE.
    if model_dtype in ["awq", "eetq", "gptq"]:
        return {"QUANTIZE": model_dtype}
    if model_dtype == "8bit":
        return {"QUANTIZE": "bitsandbytes"}
    if model_dtype == "4bit":
        return {"QUANTIZE": "bitsandbytes-nf4"}
    # Half-precision floats are passed through DTYPE, not QUANTIZE.
    if model_dtype in ["bfloat16", "float16"]:
        return {"DTYPE": model_dtype}
    # Unrecognized dtype: emit nothing and let the server use its default.
    return {}
253+
241254

242255
def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig: # noqa: C901
243256
"""
@@ -282,6 +295,7 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
282295
instance_size=args.instance_size,
283296
instance_type=args.instance_type,
284297
should_reuse_existing=args.reuse_existing,
298+
model_dtype=args.model_dtype,
285299
)
286300
return InferenceModelConfig(model=args.endpoint_model_name)
287301

src/lighteval/models/model_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def load_model_with_inference_endpoints(config: InferenceEndpointModelConfig, en
108108
model_info = ModelInfo(
109109
model_name=model.name,
110110
model_sha=model.revision,
111-
model_dtype="default",
111+
model_dtype=config.model_dtype or "default",
112112
model_size=-1,
113113
)
114114
return model, model_info

0 commit comments

Comments (0)