-import os
+from typing import Any, Optional
+
 import torch
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
-from text_generation_server.inference_engine.engine import BaseInferenceEngine
-from text_generation_server.utils.hub import TRUST_REMOTE_CODE
-from typing import Any, Optional
+from text_generation_server.inference_engine.hf_transformers import InferenceEngine as HFTransformersInferenceEngine
 
 
-class InferenceEngine(BaseInferenceEngine):
+class InferenceEngine(HFTransformersInferenceEngine):
     def __init__(
         self,
         model_path: str,
@@ -17,28 +16,12 @@ def __init__(
         model_config: Optional[Any],
         max_sequence_length: Optional[int],
     ) -> None:
-        super().__init__(model_path, model_config)
-
-        kwargs = {
-            "pretrained_model_name_or_path": model_path,
-            "device_map": None,
-            "local_files_only": True,
-            "trust_remote_code": TRUST_REMOTE_CODE,
-        }
-
-        if self.device.type == "cuda":
-            kwargs["device_map"] = "balanced_low_0" if self.world_size > 1 else "auto"
-
-        if quantize == "bitsandbytes":
-            # using LLM.int8()
-            kwargs["load_in_8bit"] = True
-        elif quantize is not None:
-            raise ValueError(f"{quantize} quantization not supported by hf_accelerate engine")
-        else:
-            kwargs["torch_dtype"] = dtype
-
-        slow_but_exact = os.getenv('BLOOM_SLOW_BUT_EXACT', 'false').lower() == 'true'
-        if slow_but_exact:
-            kwargs["slow_but_exact"] = True
-
-        self.model = model_class.from_pretrained(**kwargs).requires_grad_(False).eval()
+        super().__init__(
+            model_path,
+            model_class,
+            dtype,
+            quantize,
+            model_config,
+            max_sequence_length,
+            _use_accelerate=True
+        )
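For context, here is a minimal sketch of what the shared `hf_transformers` base engine that this file now delegates to might look like. This is an assumption, not the actual repo code: the constructor signature is inferred from the `super().__init__` call above, and the `_use_accelerate` handling is reconstructed from the loading logic removed in this diff (accelerate `device_map` selection, bitsandbytes LLM.int8(), dtype fallback). Note that the `BLOOM_SLOW_BUT_EXACT` workaround is dropped entirely by this change.

```python
# Hypothetical sketch of text_generation_server/inference_engine/hf_transformers.py.
# Parameter names mirror the call site above; _use_accelerate is assumed to
# toggle accelerate-style device placement versus plain single-device loading.
from typing import Any, Optional

import torch
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from text_generation_server.inference_engine.engine import BaseInferenceEngine
from text_generation_server.utils.hub import TRUST_REMOTE_CODE


class InferenceEngine(BaseInferenceEngine):
    def __init__(
        self,
        model_path: str,
        model_class: type[_BaseAutoModelClass],
        dtype: torch.dtype,
        quantize: Optional[str],
        model_config: Optional[Any],
        max_sequence_length: Optional[int],
        _use_accelerate: bool = False,  # assumption: opt-in flag for accelerate placement
    ) -> None:
        super().__init__(model_path, model_config)

        kwargs = {
            "pretrained_model_name_or_path": model_path,
            "local_files_only": True,
            "trust_remote_code": TRUST_REMOTE_CODE,
        }

        if _use_accelerate and self.device.type == "cuda":
            # Let accelerate place weights: spread across GPUs when running
            # with multiple workers, otherwise fit to the available devices.
            kwargs["device_map"] = "balanced_low_0" if self.world_size > 1 else "auto"

        if quantize == "bitsandbytes":
            kwargs["load_in_8bit"] = True  # LLM.int8()
        elif quantize is not None:
            raise ValueError(f"{quantize} quantization not supported")
        else:
            kwargs["torch_dtype"] = dtype

        self.model = model_class.from_pretrained(**kwargs).requires_grad_(False).eval()
```

Centralizing the `from_pretrained` plumbing in one base engine means the hf_accelerate subclass shrinks to a single delegating constructor, so quantization and device-placement fixes only need to land in one place.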