1- # pyright: reportMissingImports=false
2-
3- from typing import Any , Optional , Dict , List , Union , TYPE_CHECKING , cast
1+ from typing import Any , Optional , Dict , List , Union , cast
42from guardrails import Guard
53from guardrails .errors import ValidationError
64from guardrails .classes .validation_outcome import ValidationOutcome
75from guardrails .decorators .experimental import experimental
8- import importlib .util
96
10- LLAMA_INDEX_AVAILABLE = importlib .util .find_spec ("llama_index" ) is not None
117
12- if TYPE_CHECKING or LLAMA_INDEX_AVAILABLE :
8+ try :
9+ import llama_index # noqa: F401
1310 from llama_index .core .query_engine import BaseQueryEngine
1411 from llama_index .core .chat_engine .types import (
1512 BaseChatEngine ,
2825 )
2926 from llama_index .core .base .llms .types import ChatMessage
3027 from llama_index .core .prompts .mixin import PromptMixinType
28+ except ImportError :
29+ raise ImportError (
30+ "llama_index is not installed. Please install it with "
31+ "`pip install llama-index` to use GuardrailsEngine."
32+ )
3133
3234
3335class GuardrailsEngine (BaseQueryEngine , BaseChatEngine ):
@@ -38,14 +40,6 @@ def __init__(
3840 guard_kwargs : Optional [Dict [str , Any ]] = None ,
3941 callback_manager : Optional ["CallbackManager" ] = None ,
4042 ):
41- try :
42- import llama_index # noqa: F401
43- except ImportError :
44- raise ImportError (
45- "llama_index is not installed. Please install it with "
46- "`pip install llama-index` to use GuardrailsEngine."
47- )
48-
4943 self ._engine = engine
5044 self ._guard = guard
5145 self ._guard_kwargs = guard_kwargs or {}
@@ -196,12 +190,15 @@ def _create_chat_response(
196190 if self ._engine_response .metadata is None :
197191 self ._engine_response .metadata = {}
198192 self ._engine_response .metadata .update (metadata_update )
193+ # Repeat for typing purposes
194+ self ._engine_response .response = content
195+ return self ._engine_response
199196 elif isinstance (self ._engine_response , StreamingAgentChatResponse ):
200197 for key , value in metadata_update .items ():
201198 setattr (self ._engine_response , key , value )
202-
203- self ._engine_response .response = content
204- return self ._engine_response
199+ # Repeat for typing purposes
200+ self ._engine_response .response = content
201+ return self ._engine_response
205202
206203 async def _aquery (self , query_bundle : "QueryBundle" ) -> "RESPONSE_TYPE" :
207204 """Async version of _query."""
0 commit comments