
Commit eb212ba

Merge pull request #1079 from guardrails-ai/custom-llms
Fix Custom LLM Support
2 parents fd0f83b + 513b3e3

File tree

21 files changed: +779 −180 lines

docs/how_to_guides/using_llms.md

Lines changed: 46 additions & 0 deletions
@@ -289,3 +289,49 @@ for chunk in stream_chunk_generator
 ## Other LLMs
 
 See LiteLLM's documentation [here](https://docs.litellm.ai/docs/providers) for details on many other LLMs.
+
+## Custom LLM Wrappers
+In case you're using an LLM that isn't natively supported by Guardrails and you don't want to use LiteLLM, you can build a custom LLM API wrapper. To use a custom LLM, create a function that accepts a positional argument for the prompt as a string and any other arguments that you want to pass to the LLM API as keyword arguments. The function should return the output of the LLM API as a string.
+
+```python
+from typing import Optional
+
+from guardrails import Guard
+from guardrails.hub import ProfanityFree
+
+# Create a Guard instance
+guard = Guard().use(ProfanityFree())
+
+# Function that takes the prompt as a string and returns the LLM output as a string
+def my_llm_api(
+    prompt: Optional[str] = None,
+    *,
+    instructions: Optional[str] = None,
+    msg_history: Optional[list[dict]] = None,
+    **kwargs
+) -> str:
+    """Custom LLM API wrapper.
+
+    At least one of prompt, instructions or msg_history should be provided.
+
+    Args:
+        prompt (str): The prompt to be passed to the LLM API
+        instructions (str): The instructions to be passed to the LLM API
+        msg_history (list[dict]): The message history to be passed to the LLM API
+        **kwargs: Any additional arguments to be passed to the LLM API
+
+    Returns:
+        str: The output of the LLM API
+    """
+
+    # Call your LLM API here.
+    # What you pass to the LLM will depend on what arguments it accepts.
+    llm_output = some_llm(prompt, instructions, msg_history, **kwargs)
+
+    return llm_output
+
+# Wrap your LLM API call
+validated_response = guard(
+    my_llm_api,
+    prompt="Can you generate a list of 10 things that are not food?",
+    # Any other keyword arguments passed here are forwarded to my_llm_api via **kwargs
+)
+```
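
For completeness, the async path touched by this commit (`AsyncArbitraryCallable` / `get_async_llm_ask` below) expects the same shape from an async custom LLM wrapper. A minimal sketch, mirroring the `mock_async_llm` test helper added in this commit; `some_async_llm` is a hypothetical async client, not a Guardrails API:

```python
from typing import Dict, List, Optional


async def my_async_llm_api(
    prompt: Optional[str] = None,
    *,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    """Async custom LLM API wrapper.

    At least one of prompt, instructions or msg_history should be provided.
    """
    # some_async_llm is a placeholder for your own async LLM client.
    llm_output = await some_async_llm(prompt, instructions, msg_history, **kwargs)
    return llm_output
```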

guardrails/applications/text2sql.py

Lines changed: 2 additions & 2 deletions
@@ -1,14 +1,14 @@
 import asyncio
 import json
 import os
+import openai
 from string import Template
 from typing import Callable, Dict, Optional, Type, cast
 
 from guardrails.classes import ValidationOutcome
 from guardrails.document_store import DocumentStoreBase, EphemeralDocumentStore
 from guardrails.embedding import EmbeddingBase, OpenAIEmbedding
 from guardrails.guard import Guard
-from guardrails.utils.openai_utils import get_static_openai_create_func
 from guardrails.utils.sql_utils import create_sql_driver
 from guardrails.vectordb import Faiss, VectorDBBase
 
@@ -89,7 +89,7 @@ def __init__(
             reask_prompt: Prompt to use for reasking. Defaults to REASK_PROMPT.
         """
         if llm_api is None:
-            llm_api = get_static_openai_create_func()
+            llm_api = openai.completions.create
 
         self.example_formatter = example_formatter
         self.llm_api = llm_api

guardrails/formatters/json_formatter.py

Lines changed: 25 additions & 9 deletions
@@ -1,5 +1,5 @@
 import json
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union
 
 from guardrails.formatters.base_formatter import BaseFormatter
 from guardrails.llm_providers import (
@@ -99,32 +99,48 @@ def wrap_callable(self, llm_callable) -> ArbitraryCallable:
 
         if isinstance(llm_callable, HuggingFacePipelineCallable):
             model = llm_callable.init_kwargs["pipeline"]
-            return ArbitraryCallable(
-                lambda p: json.dumps(
+
+            def fn(
+                prompt: str,
+                *args,
+                instructions: Optional[str] = None,
+                msg_history: Optional[List[Dict[str, str]]] = None,
+                **kwargs,
+            ) -> str:
+                return json.dumps(
                     Jsonformer(
                         model=model.model,
                         tokenizer=model.tokenizer,
                         json_schema=self.output_schema,
-                        prompt=p,
+                        prompt=prompt,
                     )()
                 )
-            )
+
+            return ArbitraryCallable(fn)
         elif isinstance(llm_callable, HuggingFaceModelCallable):
             # This will not work because 'model_generate' is the .gen method.
             # model = self.api.init_kwargs["model_generate"]
             # Use the __self__ to grab the base model for passing into JF.
             model = llm_callable.init_kwargs["model_generate"].__self__
             tokenizer = llm_callable.init_kwargs["tokenizer"]
-            return ArbitraryCallable(
-                lambda p: json.dumps(
+
+            def fn(
+                prompt: str,
+                *args,
+                instructions: Optional[str] = None,
+                msg_history: Optional[List[Dict[str, str]]] = None,
+                **kwargs,
+            ) -> str:
+                return json.dumps(
                     Jsonformer(
                         model=model,
                         tokenizer=tokenizer,
                         json_schema=self.output_schema,
-                        prompt=p,
+                        prompt=prompt,
                     )()
                 )
-            )
+
+            return ArbitraryCallable(fn)
         else:
             raise ValueError(
                 "JsonFormatter can only be used with HuggingFace*Callable."

guardrails/llm_providers.py

Lines changed: 61 additions & 16 deletions
@@ -1,5 +1,6 @@
 import asyncio
 
+import inspect
 from typing import (
     Any,
     Awaitable,
@@ -27,10 +28,10 @@
 from guardrails.utils.openai_utils import (
     AsyncOpenAIClient,
     OpenAIClient,
-    get_static_openai_acreate_func,
-    get_static_openai_chat_acreate_func,
-    get_static_openai_chat_create_func,
-    get_static_openai_create_func,
+    is_static_openai_acreate_func,
+    is_static_openai_chat_acreate_func,
+    is_static_openai_chat_create_func,
+    is_static_openai_create_func,
 )
 from guardrails.utils.pydantic_utils import convert_pydantic_model_to_openai_fn
 from guardrails.utils.safe_get import safe_get
@@ -711,6 +712,26 @@ def _invoke_llm(self, prompt: str, pipeline: Any, *args, **kwargs) -> LLMRespons
 
 class ArbitraryCallable(PromptCallableBase):
     def __init__(self, llm_api: Optional[Callable] = None, *args, **kwargs):
+        llm_api_args = inspect.getfullargspec(llm_api)
+        if not llm_api_args.args:
+            raise ValueError(
+                "Custom LLM callables must accept"
+                " at least one positional argument for prompt!"
+            )
+        if not llm_api_args.varkw:
+            raise ValueError("Custom LLM callables must accept **kwargs!")
+        if (
+            not llm_api_args.kwonlyargs
+            or "instructions" not in llm_api_args.kwonlyargs
+            or "msg_history" not in llm_api_args.kwonlyargs
+        ):
+            warnings.warn(
+                "We recommend including 'instructions' and 'msg_history'"
+                " as keyword-only arguments for custom LLM callables."
+                " Doing so ensures these arguments are not unintentionally"
+                " passed through to other calls via **kwargs.",
+                UserWarning,
+            )
         self.llm_api = llm_api
         super().__init__(*args, **kwargs)
 
@@ -784,9 +805,9 @@ def get_llm_ask(
     except ImportError:
         pass
 
-    if llm_api == get_static_openai_create_func():
+    if is_static_openai_create_func(llm_api):
         return OpenAICallable(*args, **kwargs)
-    if llm_api == get_static_openai_chat_create_func():
+    if is_static_openai_chat_create_func(llm_api):
         return OpenAIChatCallable(*args, **kwargs)
 
     try:
@@ -1190,6 +1211,26 @@ async def invoke_llm(
 
 class AsyncArbitraryCallable(AsyncPromptCallableBase):
     def __init__(self, llm_api: Callable, *args, **kwargs):
+        llm_api_args = inspect.getfullargspec(llm_api)
+        if not llm_api_args.args:
+            raise ValueError(
+                "Custom LLM callables must accept"
+                " at least one positional argument for prompt!"
+            )
+        if not llm_api_args.varkw:
+            raise ValueError("Custom LLM callables must accept **kwargs!")
+        if (
+            not llm_api_args.kwonlyargs
+            or "instructions" not in llm_api_args.kwonlyargs
+            or "msg_history" not in llm_api_args.kwonlyargs
+        ):
+            warnings.warn(
+                "We recommend including 'instructions' and 'msg_history'"
+                " as keyword-only arguments for custom LLM callables."
+                " Doing so ensures these arguments are not unintentionally"
+                " passed through to other calls via **kwargs.",
+                UserWarning,
+            )
         self.llm_api = llm_api
         super().__init__(*args, **kwargs)
 
@@ -1241,7 +1282,7 @@ async def invoke_llm(self, *args, **kwargs) -> LLMResponse:
 
 
 def get_async_llm_ask(
-    llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs
+    llm_api: Callable[..., Awaitable[Any]], *args, **kwargs
 ) -> AsyncPromptCallableBase:
     try:
         import litellm
@@ -1252,9 +1293,12 @@ def get_async_llm_ask(
         pass
 
     # these only work with openai v0 (None otherwise)
-    if llm_api == get_static_openai_acreate_func():
+    # We no longer support OpenAI v0
+    # We should drop these checks or update the logic to support
+    # OpenAI v1 clients instead of just static methods
+    if is_static_openai_acreate_func(llm_api):
         return AsyncOpenAICallable(*args, **kwargs)
-    if llm_api == get_static_openai_chat_acreate_func():
+    if is_static_openai_chat_acreate_func(llm_api):
         return AsyncOpenAIChatCallable(*args, **kwargs)
 
     try:
@@ -1265,11 +1309,12 @@ def get_async_llm_ask(
     except ImportError:
        pass
 
-    return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs)
+    if llm_api is not None:
+        return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs)
 
 
 def model_is_supported_server_side(
-    llm_api: Optional[Union[Callable, Callable[[Any], Awaitable[Any]]]] = None,
+    llm_api: Optional[Union[Callable, Callable[..., Awaitable[Any]]]] = None,
     *args,
     **kwargs,
 ) -> bool:
@@ -1289,17 +1334,17 @@ def model_is_supported_server_side(
 
 # CONTINUOUS FIXME: Update with newly supported LLMs
 def get_llm_api_enum(
-    llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs
+    llm_api: Callable[..., Awaitable[Any]], *args, **kwargs
 ) -> Optional[LLMResource]:
     # TODO: Distinguish between v1 and v2
     model = get_llm_ask(llm_api, *args, **kwargs)
-    if llm_api == get_static_openai_create_func():
+    if is_static_openai_create_func(llm_api):
         return LLMResource.OPENAI_DOT_COMPLETION_DOT_CREATE
-    elif llm_api == get_static_openai_chat_create_func():
+    elif is_static_openai_chat_create_func(llm_api):
         return LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_CREATE
-    elif llm_api == get_static_openai_acreate_func():
+    elif is_static_openai_acreate_func(llm_api):  # This is always False
         return LLMResource.OPENAI_DOT_COMPLETION_DOT_ACREATE
-    elif llm_api == get_static_openai_chat_acreate_func():
+    elif is_static_openai_chat_acreate_func(llm_api):  # This is always False
         return LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_ACREATE
     elif isinstance(model, LiteLLMCallable):
         return LLMResource.LITELLM_DOT_COMPLETION
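
The signature validation added to `ArbitraryCallable.__init__` and `AsyncArbitraryCallable.__init__` above can be read as a small standalone check. A minimal sketch; `check_custom_llm`, `good_llm`, and `bad_llm` are illustrative names, not part of the library:

```python
import inspect
import warnings
from typing import Dict, List, Optional


def check_custom_llm(llm_api) -> None:
    # Mirrors the checks performed in ArbitraryCallable / AsyncArbitraryCallable.
    spec = inspect.getfullargspec(llm_api)
    if not spec.args:
        raise ValueError(
            "Custom LLM callables must accept"
            " at least one positional argument for prompt!"
        )
    if not spec.varkw:
        raise ValueError("Custom LLM callables must accept **kwargs!")
    if (
        not spec.kwonlyargs
        or "instructions" not in spec.kwonlyargs
        or "msg_history" not in spec.kwonlyargs
    ):
        warnings.warn(
            "We recommend including 'instructions' and 'msg_history'"
            " as keyword-only arguments for custom LLM callables.",
            UserWarning,
        )


def good_llm(
    prompt: Optional[str] = None,
    *,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    return ""


def bad_llm(prompt: str) -> str:  # no **kwargs
    return ""


check_custom_llm(good_llm)  # passes silently
try:
    check_custom_llm(bad_llm)
except ValueError as e:
    print(e)  # Custom LLM callables must accept **kwargs!
```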

guardrails/utils/openai_utils/__init__.py

Lines changed: 8 additions & 8 deletions
@@ -2,18 +2,18 @@
 from .v1 import OpenAIClientV1 as OpenAIClient
 from .v1 import (
     OpenAIServiceUnavailableError,
-    get_static_openai_acreate_func,
-    get_static_openai_chat_acreate_func,
-    get_static_openai_chat_create_func,
-    get_static_openai_create_func,
+    is_static_openai_acreate_func,
+    is_static_openai_chat_acreate_func,
+    is_static_openai_chat_create_func,
+    is_static_openai_create_func,
 )
 
 __all__ = [
     "AsyncOpenAIClient",
     "OpenAIClient",
-    "get_static_openai_create_func",
-    "get_static_openai_chat_create_func",
-    "get_static_openai_acreate_func",
-    "get_static_openai_chat_acreate_func",
+    "is_static_openai_create_func",
+    "is_static_openai_chat_create_func",
+    "is_static_openai_acreate_func",
+    "is_static_openai_chat_acreate_func",
     "OpenAIServiceUnavailableError",
 ]

guardrails/utils/openai_utils/v1.py

Lines changed: 19 additions & 9 deletions
@@ -1,4 +1,4 @@
-from typing import Any, AsyncIterable, Dict, Iterable, List, cast
+from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, cast
 
 import openai
 
@@ -12,20 +12,30 @@
 from guardrails.telemetry import trace_llm_call, trace_operation
 
 
-def get_static_openai_create_func():
-    return openai.completions.create
+def is_static_openai_create_func(llm_api: Optional[Callable]) -> bool:
+    try:
+        return llm_api == openai.completions.create
+    except openai.OpenAIError:
+        return False
 
 
-def get_static_openai_chat_create_func():
-    return openai.chat.completions.create
+def is_static_openai_chat_create_func(llm_api: Optional[Callable]) -> bool:
+    try:
+        return llm_api == openai.chat.completions.create
+    except openai.OpenAIError:
+        return False
 
 
-def get_static_openai_acreate_func():
-    return None
+def is_static_openai_acreate_func(llm_api: Optional[Callable]) -> bool:
+    # Because the static version of this does not exist in OpenAI 1.x
+    # Can we just drop these checks?
+    return False
 
 
-def get_static_openai_chat_acreate_func():
-    return None
+def is_static_openai_chat_acreate_func(llm_api: Optional[Callable]) -> bool:
+    # Because the static version of this does not exist in OpenAI 1.x
+    # Can we just drop these checks?
+    return False
 
 
 OpenAIServiceUnavailableError = openai.APIError
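
The renamed helpers replace "get the static OpenAI function" with "is this the static OpenAI function?", which is how `get_llm_ask` and `get_llm_api_enum` above now dispatch. A minimal usage sketch; the dummy API key is an assumption, only there so that touching the module-level `openai` client does not raise the `openai.OpenAIError` these predicates guard against:

```python
import openai

from guardrails.utils.openai_utils import (
    is_static_openai_chat_create_func,
    is_static_openai_create_func,
)

# Dummy key: accessing openai.completions with no key configured can raise
# openai.OpenAIError, which is exactly what the try/except in v1.py handles.
openai.api_key = "sk-dummy"

print(is_static_openai_create_func(openai.completions.create))            # True
print(is_static_openai_chat_create_func(openai.chat.completions.create))  # True
print(is_static_openai_create_func(None))                                 # False
```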
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from typing import Dict, List, Optional
+
+
+def mock_llm(
+    prompt: Optional[str] = None,
+    *args,
+    instructions: Optional[str] = None,
+    msg_history: Optional[List[Dict[str, str]]] = None,
+    **kwargs,
+) -> str:
+    return ""
+
+
+async def mock_async_llm(
+    prompt: Optional[str] = None,
+    *args,
+    instructions: Optional[str] = None,
+    msg_history: Optional[List[Dict[str, str]]] = None,
+    **kwargs,
+) -> str:
+    return ""
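
These mocks satisfy the new signature checks, so they can stand in for a real LLM when exercising a guard. A hypothetical usage sketch based on the docs example above; the bare `Guard()` with no validators and the attribute on the returned outcome are assumptions, not part of this diff:

```python
from typing import Dict, List, Optional

from guardrails import Guard


def mock_llm(
    prompt: Optional[str] = None,
    *args,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    # Same shape as the test mock added in this commit.
    return ""


guard = Guard()  # no validators; the call simply passes the output through

outcome = guard(
    mock_llm,
    prompt="Can you generate a list of 10 things that are not food?",
)
print(outcome.raw_llm_output)  # "" — the mock always returns an empty string
```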

0 commit comments
