| 1 | +"""LangChain implementation of AIProvider for LaunchDarkly AI SDK.""" |
| 2 | + |
| 3 | +from typing import Any, Dict, List, Optional |
| 4 | + |
| 5 | +from langchain_core.chat_models import BaseChatModel |
| 6 | +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage |
| 7 | + |
| 8 | +from ldai.models import AIConfigKind, LDMessage |
| 9 | +from ldai.providers.ai_provider import AIProvider |
| 10 | +from ldai.providers.types import ChatResponse, LDAIMetrics, StructuredResponse |
| 11 | +from ldai.tracker import TokenUsage |
| 12 | + |
| 13 | + |
class LangChainProvider(AIProvider):
    """
    LangChain implementation of AIProvider.

    This provider integrates LangChain models with LaunchDarkly's tracking capabilities.
    """

    def __init__(self, llm: BaseChatModel, logger: Optional[Any] = None):
        """
        Initialize the LangChain provider.

        :param llm: LangChain BaseChatModel instance
        :param logger: Optional logger for logging provider operations
        """
        super().__init__(logger)
        self._llm = llm

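    # The provider can also wrap an existing LangChain chat model directly
    # (sketch; ChatOpenAI is illustrative and requires the langchain-openai
    # package):
    #
    #     from langchain_openai import ChatOpenAI
    #     provider = LangChainProvider(ChatOpenAI(model='gpt-4o-mini'))
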
    # =============================================================================
    # MAIN FACTORY METHOD
    # =============================================================================

    @staticmethod
    async def create(ai_config: AIConfigKind, logger: Optional[Any] = None) -> 'LangChainProvider':
        """
        Static factory method to create a LangChain AIProvider from an AI configuration.

        :param ai_config: The LaunchDarkly AI configuration
        :param logger: Optional logger for the provider
        :return: Configured LangChainProvider instance
        """
        llm = await LangChainProvider.create_langchain_model(ai_config)
        return LangChainProvider(llm, logger)

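    # A minimal usage sketch: `ai_client`, `context`, and `default_value` are
    # illustrative names from the calling application, not this module.
    #
    #     config = ai_client.config('my-ai-config', context, default_value)
    #     provider = await LangChainProvider.create(config)
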
    # =============================================================================
    # INSTANCE METHODS (AIProvider Implementation)
    # =============================================================================

    async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse:
        """
        Invoke the LangChain model with a list of messages.

        :param messages: List of LDMessage objects representing the conversation
        :return: ChatResponse containing the model's response
        """
        try:
            # Convert the LDMessage list to LangChain messages
            langchain_messages = LangChainProvider.convert_messages_to_langchain(messages)

            # Get the LangChain response
            response: AIMessage = await self._llm.ainvoke(langchain_messages)

            # Generate metrics early (assumes success by default)
            metrics = LangChainProvider.get_ai_metrics_from_response(response)

            # Extract text content from the response
            content: str = ''
            if isinstance(response.content, str):
                content = response.content
            else:
                # Log a warning for non-string content (likely multimodal)
                if self.logger:
                    self.logger.warning(
                        f"Multimodal response not supported, expecting a string. "
                        f"Content type: {type(response.content)}, Content: {response.content}"
                    )
                # Update metrics to reflect content loss
                metrics.success = False

            # Create the assistant message
            assistant_message = LDMessage(role='assistant', content=content)

            return ChatResponse(
                message=assistant_message,
                metrics=metrics,
            )
        except Exception as error:
            if self.logger:
                self.logger.warning(f'LangChain model invocation failed: {error}')

            return ChatResponse(
                message=LDMessage(role='assistant', content=''),
                metrics=LDAIMetrics(success=False, usage=None),
            )

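    # Example call (sketch; `provider` is a constructed LangChainProvider):
    #
    #     response = await provider.invoke_model([
    #         LDMessage(role='system', content='You are a helpful assistant.'),
    #         LDMessage(role='user', content='Hello!'),
    #     ])
    #     print(response.message.content, response.metrics.success)
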
    async def invoke_structured_model(
        self,
        messages: List[LDMessage],
        response_structure: Dict[str, Any],
    ) -> StructuredResponse:
        """
        Invoke the LangChain model with structured output support.

        :param messages: List of LDMessage objects representing the conversation
        :param response_structure: Dictionary of output configurations keyed by output name
        :return: StructuredResponse containing the structured data
        """
        try:
            # Convert the LDMessage list to LangChain messages
            langchain_messages = LangChainProvider.convert_messages_to_langchain(messages)

            # Get the LangChain response with structured output
            # Note: with_structured_output is available on BaseChatModel in newer LangChain versions
            if hasattr(self._llm, 'with_structured_output'):
                structured_llm = self._llm.with_structured_output(response_structure)
                response = await structured_llm.ainvoke(langchain_messages)
            else:
                # Fallback: invoke normally and try to parse the content as JSON
                response_obj = await self._llm.ainvoke(langchain_messages)
                if isinstance(response_obj, AIMessage):
                    try:
                        response = json.loads(response_obj.content)
                    except (json.JSONDecodeError, TypeError):
                        response = {'content': response_obj.content}
                else:
                    response = response_obj

            # The structured output path does not expose token usage, so report
            # zero usage rather than omitting it
            metrics = LDAIMetrics(
                success=True,
                usage=TokenUsage(total=0, input=0, output=0),
            )

            return StructuredResponse(
                data=response if isinstance(response, dict) else {'result': response},
                raw_response=json.dumps(response) if not isinstance(response, str) else response,
                metrics=metrics,
            )
        except Exception as error:
            if self.logger:
                self.logger.warning(f'LangChain structured model invocation failed: {error}')

            return StructuredResponse(
                data={},
                raw_response='',
                metrics=LDAIMetrics(
                    success=False,
                    usage=TokenUsage(total=0, input=0, output=0),
                ),
            )

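    # Example (sketch): response_structure is a JSON-schema-style dict, which
    # LangChain's with_structured_output accepts directly.
    #
    #     schema = {
    #         'title': 'sentiment',
    #         'type': 'object',
    #         'properties': {'label': {'type': 'string'}},
    #     }
    #     structured = await provider.invoke_structured_model(messages, schema)
    #     print(structured.data)
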
    def get_chat_model(self) -> BaseChatModel:
        """
        Get the underlying LangChain model instance.

        :return: The LangChain BaseChatModel instance
        """
        return self._llm

    # =============================================================================
    # STATIC UTILITY METHODS
    # =============================================================================

    @staticmethod
    def map_provider(ld_provider_name: str) -> str:
        """
        Map LaunchDarkly provider names to LangChain provider names.

        This method enables seamless integration between LaunchDarkly's standardized
        provider naming and LangChain's naming conventions.

        :param ld_provider_name: LaunchDarkly provider name
        :return: LangChain provider name
        """
        lowercased_name = ld_provider_name.lower()

        mapping: Dict[str, str] = {
            'gemini': 'google-genai',
        }

        return mapping.get(lowercased_name, lowercased_name)

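    # For example (sketch):
    #
    #     LangChainProvider.map_provider('Gemini')  # -> 'google-genai'
    #     LangChainProvider.map_provider('OpenAI')  # -> 'openai' (pass-through)
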
    @staticmethod
    def get_ai_metrics_from_response(response: AIMessage) -> LDAIMetrics:
        """
        Get AI metrics from a LangChain provider response.

        This method extracts token usage information and success status from LangChain responses
        and returns a LaunchDarkly LDAIMetrics object.

        :param response: The response from the LangChain model
        :return: LDAIMetrics with success status and token usage
        """
        # Extract token usage if available
        usage: Optional[TokenUsage] = None
        if hasattr(response, 'response_metadata') and response.response_metadata:
            token_usage = response.response_metadata.get('token_usage')
            if token_usage:
                usage = TokenUsage(
                    total=token_usage.get('total_tokens', 0) or token_usage.get('totalTokens', 0) or 0,
                    input=token_usage.get('prompt_tokens', 0) or token_usage.get('promptTokens', 0) or 0,
                    output=token_usage.get('completion_tokens', 0) or token_usage.get('completionTokens', 0) or 0,
                )

        # LangChain responses that complete successfully are considered successful by default
        return LDAIMetrics(success=True, usage=usage)

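    # For reference (sketch; exact keys vary by provider): an OpenAI-style
    # response_metadata looks roughly like
    #
    #     {'token_usage': {'prompt_tokens': 9, 'completion_tokens': 12,
    #                      'total_tokens': 21}, ...}
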
    @staticmethod
    def convert_messages_to_langchain(messages: List[LDMessage]) -> List[BaseMessage]:
        """
        Convert LaunchDarkly messages to LangChain messages.

        This helper method enables developers to work directly with LangChain message types
        while maintaining compatibility with LaunchDarkly's standardized message format.

        :param messages: List of LDMessage objects
        :return: List of LangChain message objects
        """
        result: List[BaseMessage] = []
        for msg in messages:
            if msg.role == 'system':
                result.append(SystemMessage(content=msg.content))
            elif msg.role == 'user':
                result.append(HumanMessage(content=msg.content))
            elif msg.role == 'assistant':
                result.append(AIMessage(content=msg.content))
            else:
                raise ValueError(f'Unsupported message role: {msg.role}')
        return result

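    # For example (sketch):
    #
    #     LangChainProvider.convert_messages_to_langchain(
    #         [LDMessage(role='user', content='Hi')]
    #     )
    #     # -> [HumanMessage(content='Hi')]
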
    @staticmethod
    async def create_langchain_model(ai_config: AIConfigKind) -> BaseChatModel:
        """
        Create a LangChain model from an AI configuration.

        This public helper method enables developers to initialize their own LangChain models
        using LaunchDarkly AI configurations.

        :param ai_config: The LaunchDarkly AI configuration
        :return: A configured LangChain BaseChatModel
        """
        model_name = ai_config.model.name if ai_config.model else ''
        provider = ai_config.provider.name if ai_config.provider else ''
        parameters = ai_config.model.get_parameter('parameters') if ai_config.model else {}
        if not isinstance(parameters, dict):
            parameters = {}

        # Use LangChain's init_chat_model to support multiple providers
        # Note: This requires the langchain package to be installed
        try:
            try:
                # init_chat_model is exported from langchain.chat_models in
                # langchain >= 0.2
                from langchain.chat_models import init_chat_model
            except ImportError:
                # Fallback to the defining module for versions that don't
                # re-export it at the package level
                from langchain.chat_models.base import init_chat_model

            # Map the provider name to LangChain's convention
            langchain_provider = LangChainProvider.map_provider(provider)

            # Create the model configuration
            model_kwargs = {**parameters}
            if langchain_provider:
                model_kwargs['model_provider'] = langchain_provider

            # Initialize the chat model; guard against implementations that
            # return an awaitable
            result = init_chat_model(model_name, **model_kwargs)
            if inspect.isawaitable(result):
                return await result
            return result
        except ImportError as e:
            raise ImportError(
                'langchain package is required for LangChainProvider. '
                'Install it with: pip install langchain langchain-core'
            ) from e
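

# Example (sketch): using the helper on its own to build a LangChain model from
# a LaunchDarkly AI Config without wrapping it in a provider. Which integration
# package init_chat_model needs depends on the resolved provider (for example,
# langchain-openai for OpenAI models).
#
#     llm = await LangChainProvider.create_langchain_model(config)
#     reply = await llm.ainvoke('Hello!')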