
Commit 259d917

enforce method signature for custom llm callables
1 parent 7084245 commit 259d917

11 files changed: +455 -55 lines changed

guardrails/formatters/json_formatter.py

Lines changed: 25 additions & 9 deletions
@@ -1,5 +1,5 @@
 import json
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union
 
 from guardrails.formatters.base_formatter import BaseFormatter
 from guardrails.llm_providers import (
@@ -99,32 +99,48 @@ def wrap_callable(self, llm_callable) -> ArbitraryCallable:
 
         if isinstance(llm_callable, HuggingFacePipelineCallable):
             model = llm_callable.init_kwargs["pipeline"]
-            return ArbitraryCallable(
-                lambda p: json.dumps(
+
+            def fn(
+                prompt: str,
+                *args,
+                instructions: Optional[str] = None,
+                msg_history: Optional[List[Dict[str, str]]] = None,
+                **kwargs,
+            ) -> str:
+                return json.dumps(
                     Jsonformer(
                         model=model.model,
                         tokenizer=model.tokenizer,
                         json_schema=self.output_schema,
-                        prompt=p,
+                        prompt=prompt,
                     )()
                 )
-            )
+
+            return ArbitraryCallable(fn)
         elif isinstance(llm_callable, HuggingFaceModelCallable):
             # This will not work because 'model_generate' is the .gen method.
             # model = self.api.init_kwargs["model_generate"]
             # Use the __self__ to grab the base mode for passing into JF.
             model = llm_callable.init_kwargs["model_generate"].__self__
             tokenizer = llm_callable.init_kwargs["tokenizer"]
-            return ArbitraryCallable(
-                lambda p: json.dumps(
+
+            def fn(
+                prompt: str,
+                *args,
+                instructions: Optional[str] = None,
+                msg_history: Optional[List[Dict[str, str]]] = None,
+                **kwargs,
+            ) -> str:
+                return json.dumps(
                     Jsonformer(
                         model=model,
                         tokenizer=tokenizer,
                         json_schema=self.output_schema,
-                        prompt=p,
+                        prompt=prompt,
                     )()
                 )
-            )
+
+            return ArbitraryCallable(fn)
         else:
             raise ValueError(
                 "JsonFormatter can only be used with HuggingFace*Callable."

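For context, a minimal standalone sketch (not part of the commit, standard library only) of why the old lambda wrappers above had to become full def fn(...) definitions: inspect.getfullargspec, used by the check added to guardrails/llm_providers.py below, sees no **kwargs and no keyword-only arguments on a bare lambda.

import inspect
from typing import Dict, List, Optional

old_style = lambda p: p  # the shape being removed: no keyword-only args, no **kwargs

def new_style(
    prompt: str,
    *args,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    return prompt

print(inspect.getfullargspec(old_style).varkw)       # None -> would fail the **kwargs check
print(inspect.getfullargspec(new_style).varkw)       # 'kwargs' -> accepted
print(inspect.getfullargspec(new_style).kwonlyargs)  # ['instructions', 'msg_history'] -> no warning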
guardrails/llm_providers.py

Lines changed: 46 additions & 4 deletions
@@ -1,5 +1,6 @@
 import asyncio
 
+import inspect
 from typing import (
     Any,
     Awaitable,
@@ -711,6 +712,26 @@ def _invoke_llm(self, prompt: str, pipeline: Any, *args, **kwargs) -> LLMRespons
 
 class ArbitraryCallable(PromptCallableBase):
     def __init__(self, llm_api: Optional[Callable] = None, *args, **kwargs):
+        llm_api_args = inspect.getfullargspec(llm_api)
+        if not llm_api_args.args:
+            raise ValueError(
+                "Custom LLM callables must accept"
+                " at least one positional argument for prompt!"
+            )
+        if not llm_api_args.varkw:
+            raise ValueError("Custom LLM callables must accept **kwargs!")
+        if (
+            not llm_api_args.kwonlyargs
+            or "instructions" not in llm_api_args.kwonlyargs
+            or "msg_history" not in llm_api_args.kwonlyargs
+        ):
+            warnings.warn(
+                "We recommend including 'instructions' and 'msg_history'"
+                " as keyword-only arguments for custom LLM callables."
+                " Doing so ensures these arguments are not unintentionally"
+                " passed through to other calls via **kwargs.",
+                UserWarning,
+            )
         self.llm_api = llm_api
         super().__init__(*args, **kwargs)
 
@@ -1190,6 +1211,26 @@ async def invoke_llm(
 
 class AsyncArbitraryCallable(AsyncPromptCallableBase):
     def __init__(self, llm_api: Callable, *args, **kwargs):
+        llm_api_args = inspect.getfullargspec(llm_api)
+        if not llm_api_args.args:
+            raise ValueError(
+                "Custom LLM callables must accept"
+                " at least one positional argument for prompt!"
+            )
+        if not llm_api_args.varkw:
+            raise ValueError("Custom LLM callables must accept **kwargs!")
+        if (
+            not llm_api_args.kwonlyargs
+            or "instructions" not in llm_api_args.kwonlyargs
+            or "msg_history" not in llm_api_args.kwonlyargs
+        ):
+            warnings.warn(
+                "We recommend including 'instructions' and 'msg_history'"
+                " as keyword-only arguments for custom LLM callables."
+                " Doing so ensures these arguments are not unintentionally"
+                " passed through to other calls via **kwargs.",
+                UserWarning,
+            )
         self.llm_api = llm_api
         super().__init__(*args, **kwargs)
 
@@ -1241,7 +1282,7 @@ async def invoke_llm(self, *args, **kwargs) -> LLMResponse:
 
 
 def get_async_llm_ask(
-    llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs
+    llm_api: Callable[..., Awaitable[Any]], *args, **kwargs
 ) -> AsyncPromptCallableBase:
     try:
         import litellm
@@ -1268,11 +1309,12 @@ def get_async_llm_ask(
     except ImportError:
         pass
 
-    return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs)
+    if llm_api is not None:
+        return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs)
 
 
 def model_is_supported_server_side(
-    llm_api: Optional[Union[Callable, Callable[[Any], Awaitable[Any]]]] = None,
+    llm_api: Optional[Union[Callable, Callable[..., Awaitable[Any]]]] = None,
     *args,
     **kwargs,
 ) -> bool:
@@ -1292,7 +1334,7 @@ def model_is_supported_server_side(
 
 # CONTINUOUS FIXME: Update with newly supported LLMs
 def get_llm_api_enum(
-    llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs
+    llm_api: Callable[..., Awaitable[Any]], *args, **kwargs
 ) -> Optional[LLMResource]:
     # TODO: Distinguish between v1 and v2
     model = get_llm_ask(llm_api, *args, **kwargs)
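A hedged sketch of the construction-time behaviour these checks add. It assumes ArbitraryCallable can be imported from guardrails.llm_providers, as the json_formatter.py diff above already does; both callables here are hypothetical examples, not part of the commit.

from typing import Dict, List, Optional

from guardrails.llm_providers import ArbitraryCallable

def compliant_llm(
    prompt: str,
    *args,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    return "{}"

ArbitraryCallable(compliant_llm)  # passes both checks, emits no warning

def missing_kwargs(prompt: str) -> str:  # accepts prompt but not **kwargs
    return "{}"

try:
    ArbitraryCallable(missing_kwargs)
except ValueError as err:
    print(err)  # Custom LLM callables must accept **kwargs!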
tests/integration_tests/test_assets/custom_llm.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+from typing import Dict, List, Optional
+
+
+def mock_llm(
+    prompt: Optional[str] = None,
+    *args,
+    instructions: Optional[str] = None,
+    msg_history: Optional[List[Dict[str, str]]] = None,
+    **kwargs,
+) -> str:
+    return ""
+
+
+async def mock_async_llm(
+    prompt: Optional[str] = None,
+    *args,
+    instructions: Optional[str] = None,
+    msg_history: Optional[List[Dict[str, str]]] = None,
+    **kwargs,
+) -> str:
+    return ""

tests/integration_tests/test_async.py

Lines changed: 9 additions & 12 deletions
@@ -3,6 +3,7 @@
 from guardrails import AsyncGuard, Prompt
 from guardrails.utils import docs_utils
 from guardrails.classes.llm.llm_response import LLMResponse
+from tests.integration_tests.test_assets.custom_llm import mock_async_llm
 from tests.integration_tests.test_assets.fixtures import (  # noqa
     fixture_llm_output,
     fixture_rail_spec,
@@ -12,10 +13,6 @@
 from .mock_llm_outputs import entity_extraction
 
 
-async def mock_llm(*args, **kwargs):
-    return ""
-
-
 @pytest.mark.asyncio
 async def test_entity_extraction_with_reask(mocker):
     """Test that the entity extraction works with re-asking."""
@@ -45,7 +42,7 @@ async def test_entity_extraction_with_reask(mocker):
     preprocess_prompt_spy = mocker.spy(async_runner, "preprocess_prompt")
 
     final_output = await guard(
-        llm_api=mock_llm,
+        llm_api=mock_async_llm,
         prompt_params={"document": content[:6000]},
         num_reasks=1,
     )
@@ -104,7 +101,7 @@ async def test_entity_extraction_with_noop(mocker):
     content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
     guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_NOOP)
     final_output = await guard(
-        llm_api=mock_llm,
+        llm_api=mock_async_llm,
         prompt_params={"document": content[:6000]},
         num_reasks=1,
     )
@@ -151,7 +148,7 @@ async def test_entity_extraction_with_noop_pydantic(mocker):
         prompt=entity_extraction.PYDANTIC_PROMPT,
     )
     final_output = await guard(
-        llm_api=mock_llm,
+        llm_api=mock_async_llm,
         prompt_params={"document": content[:6000]},
         num_reasks=1,
     )
@@ -192,7 +189,7 @@ async def test_entity_extraction_with_filter(mocker):
     content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
     guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_FILTER)
     final_output = await guard(
-        llm_api=mock_llm,
+        llm_api=mock_async_llm,
         prompt_params={"document": content[:6000]},
         num_reasks=1,
     )
@@ -232,7 +229,7 @@ async def test_entity_extraction_with_fix(mocker):
    content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
     guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_FIX)
     final_output = await guard(
-        llm_api=mock_llm,
+        llm_api=mock_async_llm,
         prompt_params={"document": content[:6000]},
         num_reasks=1,
     )
@@ -269,7 +266,7 @@ async def test_entity_extraction_with_refrain(mocker):
     content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
     guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_REFRAIN)
     final_output = await guard(
-        llm_api=mock_llm,
+        llm_api=mock_async_llm,
         prompt_params={"document": content[:6000]},
         num_reasks=1,
     )
@@ -295,7 +292,7 @@ async def test_rail_spec_output_parse(rail_spec, llm_output, validated_output):
     guard = AsyncGuard.from_rail_string(rail_spec)
     output = await guard.parse(
         llm_output,
-        llm_api=mock_llm,
+        llm_api=mock_async_llm,
     )
     assert output.validated_output == validated_output
 
@@ -334,7 +331,7 @@ async def test_string_rail_spec_output_parse(
     guard: AsyncGuard = AsyncGuard.from_rail_string(string_rail_spec)
     output = await guard.parse(
         string_llm_output,
-        llm_api=mock_llm,
+        llm_api=mock_async_llm,
     )
     assert output.validated_output == validated_string_output

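Why the inline mock removed above could not simply be kept: a short standalone sketch (standard library only, not part of the commit) showing how the old *args-only signature looks to the inspect-based check that llm_providers.py now applies.

import inspect

async def old_mock_llm(*args, **kwargs):  # the helper removed from test_async.py
    return ""

spec = inspect.getfullargspec(old_mock_llm)
print(spec.args)   # [] -> no named positional parameter for the prompt,
                   # so the new check raises "must accept at least one positional argument"
print(spec.varkw)  # 'kwargs' -> this requirement was already satisfied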