Skip to content

Commit 653ae23

Browse files
committed
type and test fixes
1 parent 1a39abb commit 653ae23

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

guardrails/integrations/llama_index/guardrails_engine.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
# pyright: reportMissingImports=false
2-
3-
from typing import Any, Optional, Dict, List, Union, TYPE_CHECKING, cast
1+
from typing import Any, Optional, Dict, List, Union, cast
42
from guardrails import Guard
53
from guardrails.errors import ValidationError
64
from guardrails.classes.validation_outcome import ValidationOutcome
75
from guardrails.decorators.experimental import experimental
8-
import importlib.util
96

10-
LLAMA_INDEX_AVAILABLE = importlib.util.find_spec("llama_index") is not None
117

12-
if TYPE_CHECKING or LLAMA_INDEX_AVAILABLE:
8+
try:
9+
import llama_index # noqa: F401
1310
from llama_index.core.query_engine import BaseQueryEngine
1411
from llama_index.core.chat_engine.types import (
1512
BaseChatEngine,
@@ -28,6 +25,11 @@
2825
)
2926
from llama_index.core.base.llms.types import ChatMessage
3027
from llama_index.core.prompts.mixin import PromptMixinType
28+
except ImportError:
29+
raise ImportError(
30+
"llama_index is not installed. Please install it with "
31+
"`pip install llama-index` to use GuardrailsEngine."
32+
)
3133

3234

3335
class GuardrailsEngine(BaseQueryEngine, BaseChatEngine):
@@ -38,14 +40,6 @@ def __init__(
3840
guard_kwargs: Optional[Dict[str, Any]] = None,
3941
callback_manager: Optional["CallbackManager"] = None,
4042
):
41-
try:
42-
import llama_index # noqa: F401
43-
except ImportError:
44-
raise ImportError(
45-
"llama_index is not installed. Please install it with "
46-
"`pip install llama-index` to use GuardrailsEngine."
47-
)
48-
4943
self._engine = engine
5044
self._guard = guard
5145
self._guard_kwargs = guard_kwargs or {}
@@ -196,12 +190,15 @@ def _create_chat_response(
196190
if self._engine_response.metadata is None:
197191
self._engine_response.metadata = {}
198192
self._engine_response.metadata.update(metadata_update)
193+
# Repeat for typing purposes
194+
self._engine_response.response = content
195+
return self._engine_response
199196
elif isinstance(self._engine_response, StreamingAgentChatResponse):
200197
for key, value in metadata_update.items():
201198
setattr(self._engine_response, key, value)
202-
203-
self._engine_response.response = content
204-
return self._engine_response
199+
# Repeat for typing purposes
200+
self._engine_response.response = content
201+
return self._engine_response
205202

206203
async def _aquery(self, query_bundle: "QueryBundle") -> "RESPONSE_TYPE":
207204
"""Async version of _query."""

tests/integration_tests/integrations/llama_index/test_guardrails_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,6 @@ class UnsupportedEngine:
108108
guardrails_engine = GuardrailsEngine(engine, guard)
109109

110110
with pytest.raises(ValueError, match="Unsupported engine type"):
111-
guardrails_engine.engine_api("Test prompt")
111+
guardrails_engine.engine_api(
112+
messages=[{"role": "user", "content": "Test prompt"}]
113+
)

0 commit comments

Comments (0)