
Commit 58d16ec

update some tests
1 parent 10a1dde commit 58d16ec

3 files changed: +20 -6 lines changed

guardrails/llm_providers.py

Lines changed: 10 additions & 2 deletions
@@ -379,7 +379,15 @@ def _invoke_llm(
 
 
 class HuggingFacePipelineCallable(PromptCallableBase):
-    def _invoke_llm(self, prompt: str, pipeline: Any, *args, **kwargs) -> LLMResponse:
+    def _invoke_llm(
+        self,
+        pipeline: Any,
+        *args,
+        messages: Union[
+            list[dict[str, Union[str, Prompt, Instructions]]], MessageHistory
+        ],
+        **kwargs,
+    ) -> LLMResponse:
         try:
             import transformers  # noqa: F401 # type: ignore
         except ImportError:

@@ -400,7 +408,7 @@ def _invoke_llm(self, prompt: str, pipeline: Any, *args, **kwargs) -> LLMResponse:
         temperature = kwargs.pop("temperature", None)
         if temperature == 0:
             temperature = None
-
+        prompt = messages_string(messages)
         trace_operation(
             input_mime_type="application/json",
             input_value={
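For context: the new signature takes chat-style messages instead of a bare prompt string, and messages is keyword-only because it is declared after *args; the method body then rebuilds a prompt via messages_string(messages). Below is a minimal sketch of the assumed flattening, for illustration only; the real messages_string used in llm_providers.py may differ in detail.

# Chat-style input now accepted by HuggingFacePipelineCallable._invoke_llm.
messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "What is the status of this task?"},
]

# Assumed behaviour of messages_string(): join the message contents into a
# single prompt string before it is handed to the transformers pipeline.
prompt = "\n".join(str(m["content"]) for m in messages)
print(prompt)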

tests/integration_tests/test_guard.py

Lines changed: 7 additions & 3 deletions
@@ -994,15 +994,17 @@ def custom_llm(
     guard = gd.Guard.for_pydantic(Task)
     _, dict_o, *rest = guard(
         custom_llm,
-        prompt="What is the status of this task?",
+        messages=[{"role": "user", "content": "What is the status of this task?"}],
     )
     assert dict_o == {"status": "not started"}

     return_value = pydantic.LLM_OUTPUT_ENUM_2
     guard = gd.Guard.for_pydantic(Task)
     result = guard(
         custom_llm,
-        prompt="What is the status of this task REALLY?",
+        messages=[
+            {"role": "user", "content": "What is the status of this task REALLY?"}
+        ],
         num_reasks=0,
     )

@@ -1351,7 +1353,9 @@ def test_guard_for_pydantic_with_mock_hf_pipeline():

     pipe = make_mock_pipeline()
     guard = Guard()
-    _ = guard(pipe, prompt="Don't care about the output.")
+    _ = guard(
+        pipe, messages=[{"role": "user", "content": "Don't care about the output."}]
+    )


 @pytest.mark.skipif(
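The custom_llm fixture referenced in these call sites is not part of this diff; a hypothetical messages-based callable that would satisfy the updated calls could look like the following (the name, signature, and canned payload are assumptions for illustration, not the real fixture).

# Hypothetical stand-in for the test suite's custom_llm; it ignores the chat
# messages it receives and returns a canned JSON payload matching the Task model.
def custom_llm(messages=None, *args, **kwargs):
    return '{"status": "not started"}'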

tests/unit_tests/test_llm_providers.py

Lines changed: 3 additions & 1 deletion
@@ -258,7 +258,9 @@ def test_hugging_face_pipeline_callable():
     from guardrails.llm_providers import HuggingFacePipelineCallable

     hf_model_callable = HuggingFacePipelineCallable()
-    response = hf_model_callable("Hello", pipeline=pipeline)
+    response = hf_model_callable(
+        pipeline=pipeline, messages=[{"role": "user", "content": "Hello"}]
+    )

     assert isinstance(response, LLMResponse) is True
     assert response.output == "Hello there!"
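The reason every call site in this commit had to change is the same: in the new _invoke_llm signature, messages follows *args and so can only be passed by keyword, meaning a bare prompt string no longer binds to it. A self-contained sketch of that effect (DemoCallable is an illustrative stand-in, not the real HuggingFacePipelineCallable):

class DemoCallable:
    # Same parameter ordering as the new _invoke_llm: messages is keyword-only
    # because it is declared after *args.
    def _invoke_llm(self, pipeline, *args, messages, **kwargs):
        return messages

demo = DemoCallable()

# New-style call, mirroring the updated tests:
demo._invoke_llm("fake-pipeline", messages=[{"role": "user", "content": "Hello"}])

# Old-style call with a positional prompt string would now raise
# TypeError: missing required keyword-only argument 'messages':
# demo._invoke_llm("fake-pipeline", "Hello")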

0 commit comments
