|
5 | 5 | from typing import List, Dict, Optional, Any |
6 | 6 | from label_studio_ml.model import LabelStudioMLBase |
7 | 7 |
|
8 | | -# Try new langchain imports first (v0.1.0+), fall back to old imports |
9 | | -try: |
10 | | - from langchain_community.utilities import GoogleSearchAPIWrapper |
11 | | - from langchain_core.callbacks import BaseCallbackHandler |
12 | | - from langchain.agents import initialize_agent, AgentType |
13 | | - from langchain_openai import OpenAI |
14 | | - from langchain.tools import Tool |
15 | | -except ImportError: |
16 | | - # Fall back to old imports for older langchain versions |
17 | | - from langchain.utilities import GoogleSearchAPIWrapper |
18 | | - from langchain.callbacks.base import BaseCallbackHandler |
19 | | - from langchain.agents import initialize_agent, AgentType |
20 | | - from langchain.llms import OpenAI |
21 | | - from langchain.tools import Tool |
| 8 | + |
| 9 | +# Import langchain components - use new API (v1.0+) |
| 10 | +from langchain_community.utilities import GoogleSearchAPIWrapper |
| 11 | +from langchain_core.callbacks import BaseCallbackHandler |
| 12 | +from langchain.agents import create_agent |
| 13 | +from langchain_openai import ChatOpenAI |
| 14 | +from langchain_core.tools import Tool |
| 15 | + |
22 | 16 |
|
23 | 17 | from label_studio_ml.utils import match_labels |
24 | 18 |
|
@@ -92,17 +86,16 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) - |
92 | 86 | func=search.run, |
93 | 87 | callbacks=[search_results] |
94 | 88 | )] |
95 | | - llm = OpenAI( |
| 89 | + llm = ChatOpenAI( |
96 | 90 | temperature=0, |
97 | | - model_name='gpt-3.5-turbo-instruct' |
| 91 | + model="gpt-3.5-turbo" |
98 | 92 | ) |
99 | | - agent = initialize_agent( |
100 | | - tools, |
101 | | - llm, |
102 | | - agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, |
103 | | - verbose=True, |
104 | | - max_iterations=3, |
105 | | - early_stopping_method="generate", |
| 93 | + |
| 94 | + # Use new agent API (langchain 1.0+) |
| 95 | + agent = create_agent( |
| 96 | + model=llm, |
| 97 | + tools=tools, |
| 98 | + debug=True |
106 | 99 | ) |
107 | 100 |
|
108 | 101 | labels = self.parsed_label_config[from_name]['labels'] |
@@ -131,7 +124,24 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) - |
131 | 124 | text = self.preload_task_data(task, task['data'][value]) |
132 | 125 | full_prompt = self.PROMPT_TEMPLATE.format(prompt=prompt, text=text) |
133 | 126 | logger.info(f'Full prompt: {full_prompt}') |
134 | | - llm_result = agent.run(full_prompt) |
| 127 | + # Invoke the agent with the prompt |
| 128 | + result = agent.invoke({"messages": [("user", full_prompt)]}) |
| 129 | + # Extract the response from the agent result |
| 130 | + if isinstance(result, dict) and "messages" in result: |
| 131 | + # Get the last message which should be the agent's response |
| 132 | + messages = result["messages"] |
| 133 | + if messages: |
| 134 | + last_message = messages[-1] |
| 135 | + if hasattr(last_message, 'content'): |
| 136 | + llm_result = last_message.content |
| 137 | + elif isinstance(last_message, dict) and 'content' in last_message: |
| 138 | + llm_result = last_message['content'] |
| 139 | + else: |
| 140 | + llm_result = str(last_message) |
| 141 | + else: |
| 142 | + llm_result = str(result) |
| 143 | + else: |
| 144 | + llm_result = str(result) |
135 | 145 | output_classes = match_labels(llm_result, labels) |
136 | 146 | snippets = search_results.snippets |
137 | 147 | logger.debug(f'LLM result: {llm_result}') |
|
0 commit comments