Skip to content

Commit 5bb2952

Browse files
Harrison/hf pipeline (#780)
Co-authored-by: Parth Chadha <[email protected]>
1 parent c658f0a commit 5bb2952

File tree

1 file changed

+52
-17
lines changed

1 file changed

+52
-17
lines changed

langchain/llms/huggingface_pipeline.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Wrapper around HuggingFace Pipeline APIs."""
2+
import importlib.util
3+
import logging
24
from typing import Any, List, Mapping, Optional
35

46
from pydantic import BaseModel, Extra
@@ -10,6 +12,8 @@
1012
DEFAULT_TASK = "text-generation"
1113
VALID_TASKS = ("text2text-generation", "text-generation")
1214

15+
logger = logging.getLogger()
16+
1317

1418
class HuggingFacePipeline(LLM, BaseModel):
1519
"""Wrapper around HuggingFace Pipeline API.
@@ -56,6 +60,7 @@ def from_model_id(
5660
cls,
5761
model_id: str,
5862
task: str,
63+
device: int = -1,
5964
model_kwargs: Optional[dict] = None,
6065
**kwargs: Any,
6166
) -> LLM:
@@ -68,8 +73,16 @@ def from_model_id(
6873
)
6974
from transformers import pipeline as hf_pipeline
7075

71-
_model_kwargs = model_kwargs or {}
72-
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
76+
except ImportError:
77+
raise ValueError(
78+
"Could not import transformers python package. "
79+
"Please install it with `pip install transformers`."
80+
)
81+
82+
_model_kwargs = model_kwargs or {}
83+
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
84+
85+
try:
7386
if task == "text-generation":
7487
model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
7588
elif task == "text2text-generation":
@@ -79,25 +92,47 @@ def from_model_id(
7992
f"Got invalid task {task}, "
8093
f"currently only {VALID_TASKS} are supported"
8194
)
82-
pipeline = hf_pipeline(
83-
task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs
84-
)
85-
if pipeline.task not in VALID_TASKS:
95+
except ImportError as e:
96+
raise ValueError(
97+
f"Could not load the {task} model due to missing dependencies."
98+
) from e
99+
100+
if importlib.util.find_spec("torch") is not None:
101+
import torch
102+
103+
cuda_device_count = torch.cuda.device_count()
104+
if device < -1 or (device >= cuda_device_count):
86105
raise ValueError(
87-
f"Got invalid task {pipeline.task}, "
88-
f"currently only {VALID_TASKS} are supported"
106+
f"Got device=={device}, "
107+
f"device is required to be within [-1, {cuda_device_count})"
89108
)
90-
return cls(
91-
pipeline=pipeline,
92-
model_id=model_id,
93-
model_kwargs=_model_kwargs,
94-
**kwargs,
95-
)
96-
except ImportError:
109+
if device < 0 and cuda_device_count > 0:
110+
logger.warning(
111+
"Device has %d GPUs available. "
112+
"Provide device={deviceId} to `from_model_id` to use available "
113+
"GPUs for execution. deviceId is -1 (default) for CPU and "
114+
"can be a positive integer associated with CUDA device id.",
115+
cuda_device_count,
116+
)
117+
118+
pipeline = hf_pipeline(
119+
task=task,
120+
model=model,
121+
tokenizer=tokenizer,
122+
device=device,
123+
model_kwargs=_model_kwargs,
124+
)
125+
if pipeline.task not in VALID_TASKS:
97126
raise ValueError(
98-
"Could not import transformers python package. "
99-
"Please it install it with `pip install transformers`."
127+
f"Got invalid task {pipeline.task}, "
128+
f"currently only {VALID_TASKS} are supported"
100129
)
130+
return cls(
131+
pipeline=pipeline,
132+
model_id=model_id,
133+
model_kwargs=_model_kwargs,
134+
**kwargs,
135+
)
101136

102137
@property
103138
def _identifying_params(self) -> Mapping[str, Any]:

0 commit comments

Comments
 (0)