@@ -22,7 +22,7 @@
 
 from argparse import Namespace
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 import torch
 from transformers import AutoConfig, BitsAndBytesConfig, GPTQConfig, PretrainedConfig
@@ -233,11 +233,24 @@ class InferenceEndpointModelConfig:
     region: str
     instance_size: str
     instance_type: str
+    model_dtype: str
     framework: str = "pytorch"
     endpoint_type: str = "protected"
     should_reuse_existing: bool = False
     add_special_tokens: bool = True
 
+    def get_dtype_args(self) -> Dict[str, str]:
+        model_dtype = self.model_dtype.lower()
+        if model_dtype in ["awq", "eetq", "gptq"]:
+            return {"QUANTIZE": model_dtype}
+        if model_dtype == "8bit":
+            return {"QUANTIZE": "bitsandbytes"}
+        if model_dtype == "4bit":
+            return {"QUANTIZE": "bitsandbytes-nf4"}
+        if model_dtype in ["bfloat16", "float16"]:
+            return {"DTYPE": model_dtype}
+        return {}
+
 
 def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig:  # noqa: C901
     """
@@ -282,6 +295,7 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
             instance_size=args.instance_size,
             instance_type=args.instance_type,
             should_reuse_existing=args.reuse_existing,
+            model_dtype=args.model_dtype,
         )
     return InferenceModelConfig(model=args.endpoint_model_name)
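For reference, here is a minimal runnable sketch of what the new get_dtype_args mapping does. The _EndpointConfigSketch class below is a hypothetical, trimmed-down stand-in for InferenceEndpointModelConfig that reproduces only the model_dtype field and the method added in this diff; how the returned dict is consumed (e.g. merged into endpoint environment variables) is not shown in the diff itself.

# Illustrative sketch only: a hypothetical, trimmed-down stand-in for
# InferenceEndpointModelConfig, reproducing just the model_dtype field
# and the get_dtype_args mapping added in this diff.
from dataclasses import dataclass
from typing import Dict


@dataclass
class _EndpointConfigSketch:
    model_dtype: str

    def get_dtype_args(self) -> Dict[str, str]:
        # Quantization schemes map to a QUANTIZE flag, half-precision
        # dtypes map to a DTYPE flag, and anything else yields no args.
        model_dtype = self.model_dtype.lower()
        if model_dtype in ["awq", "eetq", "gptq"]:
            return {"QUANTIZE": model_dtype}
        if model_dtype == "8bit":
            return {"QUANTIZE": "bitsandbytes"}
        if model_dtype == "4bit":
            return {"QUANTIZE": "bitsandbytes-nf4"}
        if model_dtype in ["bfloat16", "float16"]:
            return {"DTYPE": model_dtype}
        return {}


for dtype in ["GPTQ", "8bit", "4bit", "bfloat16", "float32"]:
    print(dtype, "->", _EndpointConfigSketch(model_dtype=dtype).get_dtype_args())
# GPTQ -> {'QUANTIZE': 'gptq'}
# 8bit -> {'QUANTIZE': 'bitsandbytes'}
# 4bit -> {'QUANTIZE': 'bitsandbytes-nf4'}
# bfloat16 -> {'DTYPE': 'bfloat16'}
# float32 -> {}

Note that the method lowercases model_dtype before matching, and unrecognized values fall through to an empty dict, so a caller can merge the result into its argument mapping unconditionally.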