
Commit e3f801b

Merge pull request #706 from guardrails-ai/feat/llm-inputs-lite2
Feat/llm inputs lite2
2 parents 78ece30 + 239f3bd commit e3f801b

7 files changed: +595 −598 lines changed


.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions

@@ -50,7 +50,7 @@ jobs:
       matrix:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
         pydantic-version: ["1.10.9", "2.4.2"]
-        openai-version: ["0.28.1", "1.2.4"]
+        openai-version: ["1.2.4"]
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}

@@ -87,7 +87,7 @@ jobs:
         # dependencies: ['dev', 'full']
         dependencies: ["full"]
         pydantic-version: ["1.10.9", "2.4.2"]
-        openai-version: ["0.28.1", "1.2.4"]
+        openai-version: ["1.2.4"]
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}

guardrails/llm_providers.py

Lines changed: 65 additions & 0 deletions
@@ -395,6 +395,16 @@ def _invoke_llm(
             *args,
             **kwargs,
         )
+
+        if kwargs.get("stream", False):
+            # If stream is defined and set to True,
+            # the callable returns a generator object
+            llm_response = cast(Iterable[str], response)
+            return LLMResponse(
+                output="",
+                stream_output=llm_response,
+            )
+
         return LLMResponse(
             output=response.choices[0].message.content,  # type: ignore
             prompt_token_count=response.usage.prompt_tokens,  # type: ignore
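
The streaming branch above hands the raw generator back in `stream_output` instead of waiting for a final string. A minimal usage sketch, assuming an already-configured `Guard` instance named `guard` and litellm's `completion` API; the model and `prompt_params` values are illustrative, not taken from this diff:

```python
# Sketch: streaming a litellm completion through an existing Guard.
# `guard` is assumed to be configured elsewhere (spec, validators, ...).
from litellm import completion

fragment_generator = guard(
    completion,
    model="gpt-3.5-turbo",
    prompt_params={"topic": "streaming"},
    stream=True,  # exercises the stream_output branch added above
)

# With stream=True the call yields incremental validation results assembled
# by the stream runner from the generator, rather than one final response.
for fragment in fragment_generator:
    print(fragment)
```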
@@ -782,6 +792,53 @@ async def invoke_llm(
         )


+class AsyncLiteLLMCallable(AsyncPromptCallableBase):
+    async def invoke_llm(
+        self,
+        text: str,
+        instructions: Optional[str] = None,
+        *args,
+        **kwargs,
+    ):
+        """Wrapper for Lite LLM completions.
+
+        To use Lite LLM for guardrails, do
+        ```
+        from litellm import completion
+
+        raw_llm_response, validated_response = guard(
+            completion,
+            model="gpt-3.5-turbo",
+            prompt_params={...},
+            temperature=...,
+            ...
+        )
+        ```
+        """
+        try:
+            from litellm import acompletion  # type: ignore
+        except ImportError as e:
+            raise PromptCallableException(
+                "The `litellm` package is not installed. "
+                "Install with `pip install litellm`"
+            ) from e
+
+        response = await acompletion(
+            messages=litellm_messages(
+                prompt=text,
+                instructions=instructions,
+            ),
+            *args,
+            **kwargs,
+        )
+
+        return LLMResponse(
+            output=response.choices[0].message.content,  # type: ignore
+            prompt_token_count=response.usage.prompt_tokens,  # type: ignore
+            response_token_count=response.usage.completion_tokens,  # type: ignore
+        )
+
+
 class AsyncManifestCallable(AsyncPromptCallableBase):
     async def invoke_llm(
         self,
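
An async counterpart of the docstring example above, hedged: it assumes the `guard` call can be awaited when handed `litellm.acompletion` (which the dispatch in the next hunk routes to `AsyncLiteLLMCallable`); the model, temperature, and `prompt_params` values are illustrative:

```python
# Sketch: awaiting a Guard call backed by litellm.acompletion.
# `guard` is assumed to be an already-configured Guard instance.
import asyncio

from litellm import acompletion


async def main():
    raw_llm_response, validated_response = await guard(
        acompletion,
        model="gpt-3.5-turbo",
        prompt_params={"topic": "async completions"},
        temperature=0.0,
    )
    print(validated_response)


asyncio.run(main())
```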
@@ -860,6 +917,14 @@ def get_async_llm_ask(
     except ImportError:
         pass

+    try:
+        import litellm
+
+        if llm_api == litellm.acompletion:
+            return AsyncLiteLLMCallable(*args, **kwargs)
+    except ImportError:
+        pass
+
     return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs)
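
With this dispatch, passing `litellm.acompletion` as the LLM API selects the new async callable before the fallback to `AsyncArbitraryCallable`. A small illustrative check, assuming `get_async_llm_ask` and `AsyncLiteLLMCallable` are importable from `guardrails.llm_providers` and that the LLM API is the first positional argument:

```python
# Illustrative only: confirm litellm.acompletion maps to AsyncLiteLLMCallable.
import litellm

from guardrails.llm_providers import AsyncLiteLLMCallable, get_async_llm_ask

prompt_callable = get_async_llm_ask(litellm.acompletion)
assert isinstance(prompt_callable, AsyncLiteLLMCallable)
```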

guardrails/run/stream_runner.py

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,7 @@
 from guardrails.classes.validation_outcome import ValidationOutcome
 from guardrails.datatypes import verify_metadata_requirements
 from guardrails.llm_providers import (
+    LiteLLMCallable,
     OpenAICallable,
     OpenAIChatCallable,
     PromptCallableBase,

@@ -227,6 +228,11 @@ def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str:
             content = chunk.choices[0].delta.content
             if not finished and content:
                 chunk_text = content
+        elif isinstance(api, LiteLLMCallable):
+            finished = chunk.choices[0].finish_reason
+            content = chunk.choices[0].delta.content
+            if not finished and content:
+                chunk_text = content
         else:
             try:
                 chunk_text = chunk
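
The new `LiteLLMCallable` branch reads OpenAI-style streaming chunks: text arrives in `chunk.choices[0].delta.content`, and a non-empty `finish_reason` marks the final chunk. A toy, self-contained sketch of that extraction logic; the chunk classes below are stand-ins for litellm's real streaming objects, not its actual types:

```python
# Stand-in chunk types mimicking the OpenAI-style objects litellm streams.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Delta:
    content: Optional[str]


@dataclass
class Choice:
    delta: Delta
    finish_reason: Optional[str]


@dataclass
class Chunk:
    choices: List[Choice]


def extract_chunk_text(chunk: Chunk) -> str:
    # Mirrors the elif branch added above: take delta.content until the
    # chunk that carries a finish_reason.
    chunk_text = ""
    finished = chunk.choices[0].finish_reason
    content = chunk.choices[0].delta.content
    if not finished and content:
        chunk_text = content
    return chunk_text


print(extract_chunk_text(Chunk([Choice(Delta("Hello"), None)])))  # -> Hello
print(extract_chunk_text(Chunk([Choice(Delta(None), "stop")])))   # -> (empty)
```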
