1
1
"""Wrapper around HuggingFace Pipeline APIs."""
2
+ import importlib .util
3
+ import logging
2
4
from typing import Any , List , Mapping , Optional
3
5
4
6
from pydantic import BaseModel , Extra
10
12
# Task used when the caller does not specify one.
DEFAULT_TASK = "text-generation"
# Pipeline tasks this wrapper can load and validate against.
VALID_TASKS = ("text2text-generation", "text-generation")

# Module-level logger named after this module (not the root logger), so
# library users can configure/filter it per the stdlib logging convention.
logger = logging.getLogger(__name__)
16
+
13
17
14
18
class HuggingFacePipeline (LLM , BaseModel ):
15
19
"""Wrapper around HuggingFace Pipeline API.
@classmethod
def from_model_id(
    cls,
    model_id: str,
    task: str,
    device: int = -1,
    model_kwargs: Optional[dict] = None,
    **kwargs: Any,
) -> LLM:
    """Construct a HuggingFace pipeline wrapper from a model id and task.

    Args:
        model_id: Model id to pass to ``transformers`` ``from_pretrained``.
        task: One of ``VALID_TASKS`` ("text-generation" or
            "text2text-generation").
        device: CUDA device id to run on, or -1 (default) for CPU.
        model_kwargs: Extra keyword arguments forwarded to the tokenizer,
            model, and pipeline constructors.
        **kwargs: Extra keyword arguments forwarded to the class constructor.

    Returns:
        An ``LLM`` instance wrapping the constructed pipeline.

    Raises:
        ValueError: If ``transformers`` is not installed, the task is not
            supported, the model's dependencies are missing, or ``device``
            is out of range.
    """
    try:
        from transformers import (
            AutoModelForCausalLM,
            AutoModelForSeq2SeqLM,
            AutoTokenizer,
        )
        from transformers import pipeline as hf_pipeline
    except ImportError:
        raise ValueError(
            "Could not import transformers python package. "
            "Please install it with `pip install transformers`."
        )

    _model_kwargs = model_kwargs or {}
    tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)

    try:
        if task == "text-generation":
            model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
        elif task == "text2text-generation":
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
        else:
            raise ValueError(
                f"Got invalid task {task}, "
                f"currently only {VALID_TASKS} are supported"
            )
    except ImportError as e:
        # Some models pull in optional extras (e.g. sentencepiece) at load
        # time; surface that as a ValueError with the original cause chained.
        raise ValueError(
            f"Could not load the {task} model due to missing dependencies."
        ) from e

    # Only validate the device when torch is importable; the pipeline itself
    # accepts `device` either way.
    if importlib.util.find_spec("torch") is not None:
        import torch

        cuda_device_count = torch.cuda.device_count()
        if device < -1 or (device >= cuda_device_count):
            raise ValueError(
                f"Got device=={device}, "
                f"device is required to be within [-1, {cuda_device_count})"
            )
        if device < 0 and cuda_device_count > 0:
            # GPUs exist but the caller asked for CPU — warn, don't fail.
            logger.warning(
                "Device has %d GPUs available. "
                "Provide device={deviceId} to `from_model_id` to use available "
                "GPUs for execution. deviceId is -1 (default) for CPU and "
                "can be a positive integer associated with CUDA device id.",
                cuda_device_count,
            )

    pipeline = hf_pipeline(
        task=task,
        model=model,
        tokenizer=tokenizer,
        device=device,
        model_kwargs=_model_kwargs,
    )
    # The pipeline may resolve an aliased task name; re-check it is supported.
    if pipeline.task not in VALID_TASKS:
        raise ValueError(
            f"Got invalid task {pipeline.task}, "
            f"currently only {VALID_TASKS} are supported"
        )
    return cls(
        pipeline=pipeline,
        model_id=model_id,
        model_kwargs=_model_kwargs,
        **kwargs,
    )
101
136
102
137
@property
103
138
def _identifying_params (self ) -> Mapping [str , Any ]:
0 commit comments