Skip to content

Commit 6ee62b4

Browse files
committed
create a langchain implementation of the ai provider
1 parent 8271807 commit 6ee62b4

File tree

3 files changed

+316
-9
lines changed

3 files changed

+316
-9
lines changed

ldai/providers/__init__.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,19 @@
33
from ldai.providers.ai_provider import AIProvider
44
from ldai.providers.ai_provider_factory import AIProviderFactory, SupportedAIProvider
55

6-
__all__ = [
7-
'AIProvider',
8-
'AIProviderFactory',
9-
'SupportedAIProvider',
10-
]
6+
# Export LangChain provider if available
7+
try:
8+
from ldai.providers.langchain import LangChainProvider
9+
__all__ = [
10+
'AIProvider',
11+
'AIProviderFactory',
12+
'LangChainProvider',
13+
'SupportedAIProvider',
14+
]
15+
except ImportError:
16+
__all__ = [
17+
'AIProvider',
18+
'AIProviderFactory',
19+
'SupportedAIProvider',
20+
]
1121

ldai/providers/ai_provider_factory.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _get_providers_to_try(
8080
provider_set.add(provider_name) # type: ignore
8181

8282
# Then try multi-provider packages, but avoid duplicates
83-
multi_provider_packages: List[SupportedAIProvider] = ['langchain', 'vercel']
83+
multi_provider_packages: List[SupportedAIProvider] = ['langchain']
8484
for provider in multi_provider_packages:
8585
provider_set.add(provider)
8686

@@ -100,10 +100,23 @@ async def _try_create_provider(
100100
:param logger: Optional logger
101101
:return: AIProvider instance or None if creation failed
102102
"""
103+
# Handle built-in providers (part of this package)
104+
if provider_type == 'langchain':
105+
try:
106+
from ldai.providers.langchain import LangChainProvider
107+
return await LangChainProvider.create(ai_config, logger)
108+
except ImportError as error:
109+
if logger:
110+
logger.warn(
111+
f"Error creating LangChainProvider: {error}. "
112+
f"Make sure langchain and langchain-core packages are installed."
113+
)
114+
return None
115+
116+
# For future external providers, use dynamic import
103117
provider_mappings = {
104-
'openai': ('launchdarkly_server_sdk_ai_openai', 'OpenAIProvider'),
105-
'langchain': ('launchdarkly_server_sdk_ai_langchain', 'LangChainProvider'),
106-
'vercel': ('launchdarkly_server_sdk_ai_vercel', 'VercelProvider'),
118+
# 'openai': ('launchdarkly_server_sdk_ai_openai', 'OpenAIProvider'),
119+
# 'vercel': ('launchdarkly_server_sdk_ai_vercel', 'VercelProvider'),
107120
}
108121

109122
if provider_type not in provider_mappings:
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
"""LangChain implementation of AIProvider for LaunchDarkly AI SDK."""
2+
3+
from typing import Any, Dict, List, Optional
4+
5+
from langchain_core.chat_models import BaseChatModel
6+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
7+
8+
from ldai.models import AIConfigKind, LDMessage
9+
from ldai.providers.ai_provider import AIProvider
10+
from ldai.providers.types import ChatResponse, LDAIMetrics, StructuredResponse
11+
from ldai.tracker import TokenUsage
12+
13+
14+
class LangChainProvider(AIProvider):
15+
"""
16+
LangChain implementation of AIProvider.
17+
18+
This provider integrates LangChain models with LaunchDarkly's tracking capabilities.
19+
"""
20+
21+
def __init__(self, llm: BaseChatModel, logger: Optional[Any] = None):
22+
"""
23+
Initialize the LangChain provider.
24+
25+
:param llm: LangChain BaseChatModel instance
26+
:param logger: Optional logger for logging provider operations
27+
"""
28+
super().__init__(logger)
29+
self._llm = llm
30+
31+
# =============================================================================
32+
# MAIN FACTORY METHOD
33+
# =============================================================================
34+
35+
@staticmethod
36+
async def create(ai_config: AIConfigKind, logger: Optional[Any] = None) -> 'LangChainProvider':
37+
"""
38+
Static factory method to create a LangChain AIProvider from an AI configuration.
39+
40+
:param ai_config: The LaunchDarkly AI configuration
41+
:param logger: Optional logger for the provider
42+
:return: Configured LangChainProvider instance
43+
"""
44+
llm = await LangChainProvider.create_langchain_model(ai_config)
45+
return LangChainProvider(llm, logger)
46+
47+
# =============================================================================
48+
# INSTANCE METHODS (AIProvider Implementation)
49+
# =============================================================================
50+
51+
async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse:
52+
"""
53+
Invoke the LangChain model with an array of messages.
54+
55+
:param messages: Array of LDMessage objects representing the conversation
56+
:return: ChatResponse containing the model's response
57+
"""
58+
try:
59+
# Convert LDMessage[] to LangChain messages
60+
langchain_messages = LangChainProvider.convert_messages_to_langchain(messages)
61+
62+
# Get the LangChain response
63+
response: AIMessage = await self._llm.ainvoke(langchain_messages)
64+
65+
# Generate metrics early (assumes success by default)
66+
metrics = LangChainProvider.get_ai_metrics_from_response(response)
67+
68+
# Extract text content from the response
69+
content: str = ''
70+
if isinstance(response.content, str):
71+
content = response.content
72+
else:
73+
# Log warning for non-string content (likely multimodal)
74+
if self.logger:
75+
self.logger.warn(
76+
f"Multimodal response not supported, expecting a string. "
77+
f"Content type: {type(response.content)}, Content: {response.content}"
78+
)
79+
# Update metrics to reflect content loss
80+
metrics.success = False
81+
82+
# Create the assistant message
83+
from ldai.models import LDMessage
84+
assistant_message = LDMessage(role='assistant', content=content)
85+
86+
return ChatResponse(
87+
message=assistant_message,
88+
metrics=metrics,
89+
)
90+
except Exception as error:
91+
if self.logger:
92+
self.logger.warn(f'LangChain model invocation failed: {error}')
93+
94+
from ldai.models import LDMessage
95+
return ChatResponse(
96+
message=LDMessage(role='assistant', content=''),
97+
metrics=LDAIMetrics(success=False, usage=None),
98+
)
99+
100+
async def invoke_structured_model(
101+
self,
102+
messages: List[LDMessage],
103+
response_structure: Dict[str, Any],
104+
) -> StructuredResponse:
105+
"""
106+
Invoke the LangChain model with structured output support.
107+
108+
:param messages: Array of LDMessage objects representing the conversation
109+
:param response_structure: Dictionary of output configurations keyed by output name
110+
:return: StructuredResponse containing the structured data
111+
"""
112+
try:
113+
# Convert LDMessage[] to LangChain messages
114+
langchain_messages = LangChainProvider.convert_messages_to_langchain(messages)
115+
116+
# Get the LangChain response with structured output
117+
# Note: with_structured_output is available on BaseChatModel in newer LangChain versions
118+
if hasattr(self._llm, 'with_structured_output'):
119+
structured_llm = self._llm.with_structured_output(response_structure)
120+
response = await structured_llm.ainvoke(langchain_messages)
121+
else:
122+
# Fallback: invoke normally and try to parse as JSON
123+
response_obj = await self._llm.ainvoke(langchain_messages)
124+
if isinstance(response_obj, AIMessage):
125+
import json
126+
try:
127+
response = json.loads(response_obj.content)
128+
except json.JSONDecodeError:
129+
response = {'content': response_obj.content}
130+
else:
131+
response = response_obj
132+
133+
# Using structured output doesn't support metrics
134+
metrics = LDAIMetrics(
135+
success=True,
136+
usage=TokenUsage(total=0, input=0, output=0),
137+
)
138+
139+
import json
140+
return StructuredResponse(
141+
data=response if isinstance(response, dict) else {'result': response},
142+
raw_response=json.dumps(response) if not isinstance(response, str) else response,
143+
metrics=metrics,
144+
)
145+
except Exception as error:
146+
if self.logger:
147+
self.logger.warn(f'LangChain structured model invocation failed: {error}')
148+
149+
return StructuredResponse(
150+
data={},
151+
raw_response='',
152+
metrics=LDAIMetrics(
153+
success=False,
154+
usage=TokenUsage(total=0, input=0, output=0),
155+
),
156+
)
157+
158+
def get_chat_model(self) -> BaseChatModel:
159+
"""
160+
Get the underlying LangChain model instance.
161+
162+
:return: The LangChain BaseChatModel instance
163+
"""
164+
return self._llm
165+
166+
# =============================================================================
167+
# STATIC UTILITY METHODS
168+
# =============================================================================
169+
170+
@staticmethod
171+
def map_provider(ld_provider_name: str) -> str:
172+
"""
173+
Map LaunchDarkly provider names to LangChain provider names.
174+
175+
This method enables seamless integration between LaunchDarkly's standardized
176+
provider naming and LangChain's naming conventions.
177+
178+
:param ld_provider_name: LaunchDarkly provider name
179+
:return: LangChain provider name
180+
"""
181+
lowercased_name = ld_provider_name.lower()
182+
183+
mapping: Dict[str, str] = {
184+
'gemini': 'google-genai',
185+
}
186+
187+
return mapping.get(lowercased_name, lowercased_name)
188+
189+
@staticmethod
190+
def get_ai_metrics_from_response(response: AIMessage) -> LDAIMetrics:
191+
"""
192+
Get AI metrics from a LangChain provider response.
193+
194+
This method extracts token usage information and success status from LangChain responses
195+
and returns a LaunchDarkly LDAIMetrics object.
196+
197+
:param response: The response from the LangChain model
198+
:return: LDAIMetrics with success status and token usage
199+
"""
200+
# Extract token usage if available
201+
usage: Optional[TokenUsage] = None
202+
if hasattr(response, 'response_metadata') and response.response_metadata:
203+
token_usage = response.response_metadata.get('token_usage')
204+
if token_usage:
205+
usage = TokenUsage(
206+
total=token_usage.get('total_tokens', 0) or token_usage.get('totalTokens', 0) or 0,
207+
input=token_usage.get('prompt_tokens', 0) or token_usage.get('promptTokens', 0) or 0,
208+
output=token_usage.get('completion_tokens', 0) or token_usage.get('completionTokens', 0) or 0,
209+
)
210+
211+
# LangChain responses that complete successfully are considered successful by default
212+
return LDAIMetrics(success=True, usage=usage)
213+
214+
@staticmethod
215+
def convert_messages_to_langchain(messages: List[LDMessage]) -> List[BaseMessage]:
216+
"""
217+
Convert LaunchDarkly messages to LangChain messages.
218+
219+
This helper method enables developers to work directly with LangChain message types
220+
while maintaining compatibility with LaunchDarkly's standardized message format.
221+
222+
:param messages: List of LDMessage objects
223+
:return: List of LangChain message objects
224+
"""
225+
result: List[BaseMessage] = []
226+
for msg in messages:
227+
if msg.role == 'system':
228+
result.append(SystemMessage(content=msg.content))
229+
elif msg.role == 'user':
230+
result.append(HumanMessage(content=msg.content))
231+
elif msg.role == 'assistant':
232+
result.append(AIMessage(content=msg.content))
233+
else:
234+
raise ValueError(f'Unsupported message role: {msg.role}')
235+
return result
236+
237+
@staticmethod
238+
async def create_langchain_model(ai_config: AIConfigKind) -> BaseChatModel:
239+
"""
240+
Create a LangChain model from an AI configuration.
241+
242+
This public helper method enables developers to initialize their own LangChain models
243+
using LaunchDarkly AI configurations.
244+
245+
:param ai_config: The LaunchDarkly AI configuration
246+
:return: A configured LangChain BaseChatModel
247+
"""
248+
model_name = ai_config.model.name if ai_config.model else ''
249+
provider = ai_config.provider.name if ai_config.provider else ''
250+
parameters = ai_config.model.get_parameter('parameters') if ai_config.model else {}
251+
if not isinstance(parameters, dict):
252+
parameters = {}
253+
254+
# Use LangChain's init_chat_model to support multiple providers
255+
# Note: This requires langchain package to be installed
256+
try:
257+
# Try to import init_chat_model from langchain.chat_models
258+
# This is available in langchain >= 0.1.0
259+
try:
260+
from langchain.chat_models import init_chat_model
261+
except ImportError:
262+
# Fallback for older versions or different import path
263+
from langchain.chat_models.universal import init_chat_model
264+
265+
# Map provider name
266+
langchain_provider = LangChainProvider.map_provider(provider)
267+
268+
# Create model configuration
269+
model_kwargs = {**parameters}
270+
if langchain_provider:
271+
model_kwargs['model_provider'] = langchain_provider
272+
273+
# Initialize the chat model (init_chat_model may be async or sync)
274+
result = init_chat_model(model_name, **model_kwargs)
275+
# Handle both sync and async initialization
276+
if hasattr(result, '__await__'):
277+
return await result
278+
return result
279+
except ImportError as e:
280+
raise ImportError(
281+
'langchain package is required for LangChainProvider. '
282+
'Install it with: pip install langchain langchain-core'
283+
) from e
284+

0 commit comments

Comments
 (0)