46 changes: 46 additions & 0 deletions docs/how_to_guides/using_llms.md
@@ -289,3 +289,49 @@ for chunk in stream_chunk_generator
## Other LLMs

See LiteLLM’s documentation [here](https://docs.litellm.ai/docs/providers) for details on many other LLMs.

## Custom LLM Wrappers
If you're using an LLM that isn't natively supported by Guardrails and you don't want to use LiteLLM, you can build a custom LLM API wrapper. To do so, create a function that takes the prompt as a positional string argument, accepts any other values you want to forward to the LLM API as keyword arguments (it must accept `**kwargs`), and returns the LLM output as a string. Including `instructions` and `msg_history` as keyword-only arguments is recommended so they aren't unintentionally passed through to other calls via `**kwargs`.

```python
from typing import Dict, List, Optional

from guardrails import Guard
from guardrails.hub import ProfanityFree

# Create a Guard instance with the ProfanityFree validator
guard = Guard().use(ProfanityFree())

# Function that takes the prompt as a string and returns the LLM output as a string
def my_llm_api(
    prompt: Optional[str] = None,
    *,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    """Custom LLM API wrapper.

    At least one of prompt, instructions, or msg_history should be provided.

    Args:
        prompt (str): The prompt to be passed to the LLM API
        instructions (str): The instructions to be passed to the LLM API
        msg_history (list[dict]): The message history to be passed to the LLM API
        **kwargs: Any additional arguments to be passed to the LLM API

    Returns:
        str: The output of the LLM API
    """

    # Call your LLM API here.
    # What you pass to the LLM will depend on which arguments it accepts.
    llm_output = some_llm(prompt, instructions, msg_history, **kwargs)

    return llm_output

# Wrap your LLM API call
validated_response = guard(
    my_llm_api,
    prompt="Can you generate a list of 10 things that are not food?",
    # Pass any additional keyword arguments your LLM API accepts here
)
```
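
The same pattern applies to asynchronous LLMs: write an `async` function with the same signature and await the guard. Below is a minimal sketch, assuming your Guardrails version exposes `AsyncGuard` and using a placeholder `some_async_llm` coroutine for your own client call:

```python
from typing import Dict, List, Optional

from guardrails import AsyncGuard  # assumes AsyncGuard is available in your version
from guardrails.hub import ProfanityFree

guard = AsyncGuard().use(ProfanityFree())

async def my_async_llm_api(
    prompt: Optional[str] = None,
    *,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    """Async custom LLM API wrapper; returns the LLM output as a string."""
    # `some_async_llm` is a placeholder for your own async client call.
    return await some_async_llm(prompt, instructions, msg_history, **kwargs)

# Awaiting the guard validates the async LLM's output.
validated_response = await guard(
    my_async_llm_api,
    prompt="Can you generate a list of 10 things that are not food?",
)
```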
4 changes: 2 additions & 2 deletions guardrails/applications/text2sql.py
@@ -1,14 +1,14 @@
import asyncio
import json
import os
import openai
from string import Template
from typing import Callable, Dict, Optional, Type, cast

from guardrails.classes import ValidationOutcome
from guardrails.document_store import DocumentStoreBase, EphemeralDocumentStore
from guardrails.embedding import EmbeddingBase, OpenAIEmbedding
from guardrails.guard import Guard
from guardrails.utils.openai_utils import get_static_openai_create_func
from guardrails.utils.sql_utils import create_sql_driver
from guardrails.vectordb import Faiss, VectorDBBase

@@ -89,7 +89,7 @@ def __init__(
reask_prompt: Prompt to use for reasking. Defaults to REASK_PROMPT.
"""
if llm_api is None:
llm_api = get_static_openai_create_func()
llm_api = openai.completions.create

self.example_formatter = example_formatter
self.llm_api = llm_api
34 changes: 25 additions & 9 deletions guardrails/formatters/json_formatter.py
@@ -1,5 +1,5 @@
import json
from typing import Optional, Union
from typing import Dict, List, Optional, Union

from guardrails.formatters.base_formatter import BaseFormatter
from guardrails.llm_providers import (
@@ -99,32 +99,48 @@ def wrap_callable(self, llm_callable) -> ArbitraryCallable:

if isinstance(llm_callable, HuggingFacePipelineCallable):
model = llm_callable.init_kwargs["pipeline"]
return ArbitraryCallable(
lambda p: json.dumps(

def fn(
prompt: str,
*args,
instructions: Optional[str] = None,
msg_history: Optional[List[Dict[str, str]]] = None,
**kwargs,
) -> str:
return json.dumps(
Jsonformer(
model=model.model,
tokenizer=model.tokenizer,
json_schema=self.output_schema,
prompt=p,
prompt=prompt,
)()
)
)

return ArbitraryCallable(fn)
elif isinstance(llm_callable, HuggingFaceModelCallable):
# This will not work because 'model_generate' is the .gen method.
# model = self.api.init_kwargs["model_generate"]
# Use the __self__ to grab the base model for passing into JF.
model = llm_callable.init_kwargs["model_generate"].__self__
tokenizer = llm_callable.init_kwargs["tokenizer"]
return ArbitraryCallable(
lambda p: json.dumps(

def fn(
prompt: str,
*args,
instructions: Optional[str] = None,
msg_history: Optional[List[Dict[str, str]]] = None,
**kwargs,
) -> str:
return json.dumps(
Jsonformer(
model=model,
tokenizer=tokenizer,
json_schema=self.output_schema,
prompt=p,
prompt=prompt,
)()
)
)

return ArbitraryCallable(fn)
else:
raise ValueError(
"JsonFormatter can only be used with HuggingFace*Callable."
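
For context on what the refactored wrappers in `json_formatter.py` ultimately call into, here is a standalone sketch of the same Jsonformer pattern. It assumes the `jsonformer` and `transformers` packages and uses `gpt2` purely as an illustrative model; the formatter itself reuses whatever model and tokenizer the wrapped HuggingFace callable was built with.

```python
import json

from jsonformer import Jsonformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model choice; any causal HF model works here.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

output_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
}

# Jsonformer constrains decoding to the JSON schema and returns a dict,
# which the formatter then serializes with json.dumps.
result = Jsonformer(
    model=model,
    tokenizer=tokenizer,
    json_schema=output_schema,
    prompt="Describe a person named Alice.",
)()
print(json.dumps(result))
```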
77 changes: 61 additions & 16 deletions guardrails/llm_providers.py
@@ -1,5 +1,6 @@
import asyncio

import inspect
from typing import (
Any,
Awaitable,
@@ -27,10 +28,10 @@
from guardrails.utils.openai_utils import (
AsyncOpenAIClient,
OpenAIClient,
get_static_openai_acreate_func,
get_static_openai_chat_acreate_func,
get_static_openai_chat_create_func,
get_static_openai_create_func,
is_static_openai_acreate_func,
is_static_openai_chat_acreate_func,
is_static_openai_chat_create_func,
is_static_openai_create_func,
)
from guardrails.utils.pydantic_utils import convert_pydantic_model_to_openai_fn
from guardrails.utils.safe_get import safe_get
@@ -711,6 +712,26 @@ def _invoke_llm(self, prompt: str, pipeline: Any, *args, **kwargs) -> LLMResponse

class ArbitraryCallable(PromptCallableBase):
def __init__(self, llm_api: Optional[Callable] = None, *args, **kwargs):
llm_api_args = inspect.getfullargspec(llm_api)
if not llm_api_args.args:
raise ValueError(
"Custom LLM callables must accept"
" at least one positional argument for prompt!"
)
if not llm_api_args.varkw:
raise ValueError("Custom LLM callables must accept **kwargs!")
if (
not llm_api_args.kwonlyargs
or "instructions" not in llm_api_args.kwonlyargs
or "msg_history" not in llm_api_args.kwonlyargs
):
warnings.warn(
"We recommend including 'instructions' and 'msg_history'"
" as keyword-only arguments for custom LLM callables."
" Doing so ensures these arguments are not uninentionally"
" passed through to other calls via **kwargs.",
UserWarning,
)
self.llm_api = llm_api
super().__init__(*args, **kwargs)

@@ -784,9 +805,9 @@ def get_llm_ask(
except ImportError:
pass

if llm_api == get_static_openai_create_func():
if is_static_openai_create_func(llm_api):
return OpenAICallable(*args, **kwargs)
if llm_api == get_static_openai_chat_create_func():
if is_static_openai_chat_create_func(llm_api):
return OpenAIChatCallable(*args, **kwargs)

try:
@@ -1190,6 +1211,26 @@ async def invoke_llm(

class AsyncArbitraryCallable(AsyncPromptCallableBase):
def __init__(self, llm_api: Callable, *args, **kwargs):
llm_api_args = inspect.getfullargspec(llm_api)
if not llm_api_args.args:
raise ValueError(
"Custom LLM callables must accept"
" at least one positional argument for prompt!"
)
if not llm_api_args.varkw:
raise ValueError("Custom LLM callables must accept **kwargs!")
if (
not llm_api_args.kwonlyargs
or "instructions" not in llm_api_args.kwonlyargs
or "msg_history" not in llm_api_args.kwonlyargs
):
warnings.warn(
"We recommend including 'instructions' and 'msg_history'"
" as keyword-only arguments for custom LLM callables."
" Doing so ensures these arguments are not uninentionally"
" passed through to other calls via **kwargs.",
UserWarning,
)
self.llm_api = llm_api
super().__init__(*args, **kwargs)

@@ -1241,7 +1282,7 @@ async def invoke_llm(self, *args, **kwargs) -> LLMResponse:


def get_async_llm_ask(
llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs
llm_api: Callable[..., Awaitable[Any]], *args, **kwargs
) -> AsyncPromptCallableBase:
try:
import litellm
@@ -1252,9 +1293,12 @@ def get_async_llm_ask(
pass

# these only work with openai v0 (None otherwise)
if llm_api == get_static_openai_acreate_func():
# We no longer support OpenAI v0
# We should drop these checks or update the logic to support
# OpenAI v1 clients instead of just static methods
if is_static_openai_acreate_func(llm_api):
return AsyncOpenAICallable(*args, **kwargs)
if llm_api == get_static_openai_chat_acreate_func():
if is_static_openai_chat_acreate_func(llm_api):
return AsyncOpenAIChatCallable(*args, **kwargs)

try:
@@ -1265,11 +1309,12 @@ def get_async_llm_ask(
except ImportError:
pass

return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs)
if llm_api is not None:
return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs)


def model_is_supported_server_side(
llm_api: Optional[Union[Callable, Callable[[Any], Awaitable[Any]]]] = None,
llm_api: Optional[Union[Callable, Callable[..., Awaitable[Any]]]] = None,
*args,
**kwargs,
) -> bool:
@@ -1289,17 +1334,17 @@ def model_is_supported_server_side(

# CONTINUOUS FIXME: Update with newly supported LLMs
def get_llm_api_enum(
llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs
llm_api: Callable[..., Awaitable[Any]], *args, **kwargs
) -> Optional[LLMResource]:
# TODO: Distinguish between v1 and v2
model = get_llm_ask(llm_api, *args, **kwargs)
if llm_api == get_static_openai_create_func():
if is_static_openai_create_func(llm_api):
return LLMResource.OPENAI_DOT_COMPLETION_DOT_CREATE
elif llm_api == get_static_openai_chat_create_func():
elif is_static_openai_chat_create_func(llm_api):
return LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_CREATE
elif llm_api == get_static_openai_acreate_func():
elif is_static_openai_acreate_func(llm_api): # This is always False
return LLMResource.OPENAI_DOT_COMPLETION_DOT_ACREATE
elif llm_api == get_static_openai_chat_acreate_func():
elif is_static_openai_chat_acreate_func(llm_api): # This is always False
return LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_ACREATE
elif isinstance(model, LiteLLMCallable):
return LLMResource.LITELLM_DOT_COMPLETION
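
As a quick illustration of what the new `inspect.getfullargspec` validation in `ArbitraryCallable.__init__` (and its async counterpart) keys off of, here is a standard-library-only sketch with hypothetical `good_llm`/`bad_llm` callables:

```python
import inspect
from typing import Dict, List, Optional

def good_llm(
    prompt: Optional[str] = None,
    *,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    return ""

def bad_llm() -> str:  # no positional prompt argument, no **kwargs
    return ""

good = inspect.getfullargspec(good_llm)
bad = inspect.getfullargspec(bad_llm)

# Passes both checks: a positional argument and a **kwargs catch-all.
assert good.args == ["prompt"]
assert good.varkw == "kwargs"
# The keyword-only arguments are what the UserWarning looks for.
assert set(good.kwonlyargs) == {"instructions", "msg_history"}

# Would raise ValueError in ArbitraryCallable.__init__: no args, no varkw.
assert bad.args == [] and bad.varkw is None
```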
16 changes: 8 additions & 8 deletions guardrails/utils/openai_utils/__init__.py
@@ -2,18 +2,18 @@
from .v1 import OpenAIClientV1 as OpenAIClient
from .v1 import (
OpenAIServiceUnavailableError,
get_static_openai_acreate_func,
get_static_openai_chat_acreate_func,
get_static_openai_chat_create_func,
get_static_openai_create_func,
is_static_openai_acreate_func,
is_static_openai_chat_acreate_func,
is_static_openai_chat_create_func,
is_static_openai_create_func,
)

__all__ = [
"AsyncOpenAIClient",
"OpenAIClient",
"get_static_openai_create_func",
"get_static_openai_chat_create_func",
"get_static_openai_acreate_func",
"get_static_openai_chat_acreate_func",
"is_static_openai_create_func",
"is_static_openai_chat_create_func",
"is_static_openai_acreate_func",
"is_static_openai_chat_acreate_func",
"OpenAIServiceUnavailableError",
]
28 changes: 19 additions & 9 deletions guardrails/utils/openai_utils/v1.py
@@ -1,4 +1,4 @@
from typing import Any, AsyncIterable, Dict, Iterable, List, cast
from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, cast

import openai

@@ -12,20 +12,30 @@
from guardrails.telemetry import trace_llm_call, trace_operation


def get_static_openai_create_func():
return openai.completions.create
def is_static_openai_create_func(llm_api: Optional[Callable]) -> bool:
try:
return llm_api == openai.completions.create
except openai.OpenAIError:
return False


def get_static_openai_chat_create_func():
return openai.chat.completions.create
def is_static_openai_chat_create_func(llm_api: Optional[Callable]) -> bool:
try:
return llm_api == openai.chat.completions.create
except openai.OpenAIError:
return False


def get_static_openai_acreate_func():
return None
def is_static_openai_acreate_func(llm_api: Optional[Callable]) -> bool:
# Because the static version of this does not exist in OpenAI 1.x
# Can we just drop these checks?
return False


def get_static_openai_chat_acreate_func():
return None
def is_static_openai_chat_acreate_func(llm_api: Optional[Callable]) -> bool:
# Because the static version of this does not exist in OpenAI 1.x
# Can we just drop these checks?
return False


OpenAIServiceUnavailableError = openai.APIError
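
A brief sketch of how the new `is_static_openai_*` predicates behave. The `try/except openai.OpenAIError` is there because, in OpenAI 1.x, touching the module-level `openai.completions` proxy can itself raise (for example, when no API key is configured), in which case the check should simply report `False`:

```python
from guardrails.utils.openai_utils import (
    is_static_openai_chat_create_func,
    is_static_openai_create_func,
)

def my_llm(prompt: str, **kwargs) -> str:
    return ""

# A custom callable is never the static OpenAI helper, and the check does not
# raise even if the openai client is unconfigured.
assert is_static_openai_create_func(my_llm) is False
assert is_static_openai_chat_create_func(my_llm) is False

# With credentials configured, the module-level helper matches itself:
# is_static_openai_create_func(openai.completions.create) -> True
```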
21 changes: 21 additions & 0 deletions tests/integration_tests/test_assets/custom_llm.py
@@ -0,0 +1,21 @@
from typing import Dict, List, Optional


def mock_llm(
prompt: Optional[str] = None,
*args,
instructions: Optional[str] = None,
msg_history: Optional[List[Dict[str, str]]] = None,
**kwargs,
) -> str:
return ""


async def mock_async_llm(
prompt: Optional[str] = None,
*args,
instructions: Optional[str] = None,
msg_history: Optional[List[Dict[str, str]]] = None,
**kwargs,
) -> str:
return ""