Commit 378a1f2

fix(openrons-ai-server, opentrons-ai-client): predict method (#16967)
# Overview

Two changes:

- **Backend**: Previously, the two endpoints ended up sharing the same model conversation state, which confused message processing. Now each endpoint uses its own separate pathway with per-request message state (see the sketch below).
- **Frontend**: The Flex Gripper option now applies only to the Flex robot (`opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx`).

Closes AUTH-1076

## Test Plan and Hands on Testing

- Visit `opentrons.ai`.
- Click `Update an existing protocol` and follow the instructions.
- Click `Create a new protocol` and follow the instructions. Once you have provided labware and the other required information, click Submit. You will be taken to the chat window, where you should see the generated protocol; it should not start with 'Simulation is successful'.

## Changelog

- Backend: `AnthropicPredict.predict` and `reset` replaced by separate `create`/`update` paths that build their message list per request.
- Frontend: `generateChatPrompt` includes the Flex Gripper line only when the selected robot is the Opentrons Flex.

## Review requests

All tests are passing.

## Risk assessment

Low
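The essence of the backend fix, for reviewers who want it at a glance: conversation state is now assembled per request instead of mutating a shared list. A minimal sketch of the pattern, not the actual implementation (the class name and the stubbed `_call_model` are illustrative):

```python
from typing import Any, Dict, List, Optional

class PredictSketch:
    """Illustrative only: shows why copying cached docs per request fixes the bug."""

    def __init__(self) -> None:
        # Built once at startup; shared read-only by every request.
        self.cached_docs: List[Dict[str, Any]] = [
            {"role": "user", "content": "<document>...cached API docs...</document>"}
        ]

    def process_message(self, user_id: str, prompt: str, history: Optional[List[Dict[str, Any]]] = None) -> str:
        # The fix: build a fresh message list per call instead of appending to a
        # shared self._messages, so create and update requests cannot interleave.
        messages = self.cached_docs.copy()
        if history:
            messages += history
        messages.append({"role": "user", "content": prompt})
        return self._call_model(user_id, messages)

    def _call_model(self, user_id: str, messages: List[Dict[str, Any]]) -> str:
        return "stubbed model response"  # stand-in for client.messages.create(...)
```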
1 parent b6e29e9 · commit 378a1f2

8 files changed: +701 additions, −237 deletions

opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx

2 additions & 1 deletion

```diff
@@ -182,7 +182,8 @@ export function generateChatPrompt(
       .join('\n')
     : `- ${t(values.instruments.pipettes)}`
   const flexGripper =
-    values.instruments.flexGripper === FLEX_GRIPPER
+    values.instruments.flexGripper === FLEX_GRIPPER &&
+    values.instruments.robot === OPENTRONS_FLEX
       ? `\n- ${t('with_flex_gripper')}`
       : ''
   const modules = values.modules
```

opentrons-ai-server/api/domain/anthropic_predict.py

41 additions & 35 deletions

```diff
@@ -1,6 +1,6 @@
 import uuid
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal
 
 import requests
 import structlog
@@ -23,7 +23,7 @@ def __init__(self, settings: Settings) -> None:
         self.model_name: str = settings.anthropic_model_name
         self.system_prompt: str = SYSTEM_PROMPT
         self.path_docs: Path = ROOT_PATH / "api" / "storage" / "docs"
-        self._messages: List[MessageParam] = [
+        self.cached_docs: List[MessageParam] = [
             {
                 "role": "user",
                 "content": [
@@ -77,19 +77,26 @@ def get_docs(self) -> str:
         return "\n".join(xml_output)
 
     @tracer.wrap()
-    def generate_message(self, max_tokens: int = 4096) -> Message:
+    def _process_message(
+        self, user_id: str, messages: List[MessageParam], message_type: Literal["create", "update"], max_tokens: int = 4096
+    ) -> Message:
+        """
+        Internal method to handle message processing with different system prompts.
+        For now, system prompt is the same.
+        """
 
-        response = self.client.messages.create(
+        response: Message = self.client.messages.create(
             model=self.model_name,
             system=self.system_prompt,
             max_tokens=max_tokens,
-            messages=self._messages,
+            messages=messages,
             tools=self.tools,  # type: ignore
             extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
+            metadata={"user_id": user_id},
         )
 
         logger.info(
-            "Token usage",
+            f"Token usage: {message_type.capitalize()}",
             extra={
                 "input_tokens": response.usage.input_tokens,
                 "output_tokens": response.usage.output_tokens,
@@ -100,15 +107,23 @@ def generate_message(self, max_tokens: int = 4096) -> Message:
         return response
 
     @tracer.wrap()
-    def predict(self, prompt: str) -> str | None:
+    def process_message(
+        self, user_id: str, prompt: str, history: List[MessageParam] | None = None, message_type: Literal["create", "update"] = "create"
+    ) -> str | None:
+        """Unified method for creating and updating messages"""
         try:
-            self._messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)})
-            response = self.generate_message()
+            messages: List[MessageParam] = self.cached_docs.copy()
+            if history:
+                messages += history
+
+            messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)})
+            response = self._process_message(user_id=user_id, messages=messages, message_type=message_type)
+
             if response.content[-1].type == "tool_use":
                 tool_use = response.content[-1]
-                self._messages.append({"role": "assistant", "content": response.content})
+                messages.append({"role": "assistant", "content": response.content})
                 result = self.handle_tool_use(tool_use.name, tool_use.input)  # type: ignore
-                self._messages.append(
+                messages.append(
                     {
                         "role": "user",
                         "content": [
@@ -120,25 +135,26 @@ def predict(self, prompt: str) -> str | None:
                         ],
                     }
                 )
-                follow_up = self.generate_message()
-                response_text = follow_up.content[0].text  # type: ignore
-                self._messages.append({"role": "assistant", "content": response_text})
-                return response_text
+                follow_up = self._process_message(user_id=user_id, messages=messages, message_type=message_type)
+                return follow_up.content[0].text  # type: ignore
 
             elif response.content[0].type == "text":
-                response_text = response.content[0].text
-                self._messages.append({"role": "assistant", "content": response_text})
-                return response_text
+                return response.content[0].text
 
             logger.error("Unexpected response type")
             return None
-        except IndexError as e:
-            logger.error("Invalid response format", extra={"error": str(e)})
-            return None
         except Exception as e:
-            logger.error("Error in predict method", extra={"error": str(e)})
+            logger.error(f"Error in {message_type} method", extra={"error": str(e)})
             return None
 
+    @tracer.wrap()
+    def create(self, user_id: str, prompt: str, history: List[MessageParam] | None = None) -> str | None:
+        return self.process_message(user_id, prompt, history, "create")
+
+    @tracer.wrap()
+    def update(self, user_id: str, prompt: str, history: List[MessageParam] | None = None) -> str | None:
+        return self.process_message(user_id, prompt, history, "update")
+
     @tracer.wrap()
     def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str:
         if func_name == "simulate_protocol":
@@ -148,17 +164,6 @@ def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str:
             logger.error("Unknown tool", extra={"tool": func_name})
             raise ValueError(f"Unknown tool: {func_name}")
 
-    @tracer.wrap()
-    def reset(self) -> None:
-        self._messages = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": DOCUMENTS.format(doc_content=self.get_docs()), "cache_control": {"type": "ephemeral"}}  # type: ignore
-                ],
-            }
-        ]
-
     @tracer.wrap()
     def simulate_protocol(self, protocol: str) -> str:
         url = "https://Opentrons-simulator.hf.space/protocol"
@@ -197,8 +202,9 @@ def main() -> None:
 
     settings = Settings()
     llm = AnthropicPredict(settings)
-    prompt = Prompt.ask("Type a prompt to send to the Anthropic API:")
-    completion = llm.predict(prompt)
+    Prompt.ask("Type a prompt to send to the Anthropic API:")
+
+    completion = llm.create(user_id="1", prompt="hi", history=None)
     print(completion)
```
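With `predict` and `reset` removed, callers go through the new `create`/`update` wrappers and pass the user id and history explicitly. Roughly, following the same setup as `main()` above (the prompt strings and user id are illustrative):

```python
settings = Settings()
llm = AnthropicPredict(settings)

# Fresh protocol: no prior conversation; the cached docs come from llm.cached_docs.
protocol = llm.create(user_id="42", prompt="Write a protocol using the Flex...", history=None)

# Follow-up edit: replay the earlier exchange as history instead of relying on
# mutable state inside the class.
revised = llm.update(
    user_id="42",
    prompt="Change the aspirate volume to 50 uL",
    history=[{"role": "assistant", "content": protocol}],
)
```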

opentrons-ai-server/api/domain/config_anthropic.py

19 additions & 17 deletions

```diff
@@ -4,14 +4,11 @@
 
 Your key responsibilities:
 1. Welcome scientists warmly and understand their protocol needs
-2. Generate accurate Python protocols using standard Opentrons labware
+2. Generate accurate Python protocols using standard Opentrons labware (see <source> standard-loadname-info.md </source> in <document>)
 3. Provide clear explanations and documentation
 4. Flag potential safety or compatibility issues
 5. Suggest protocol optimizations when appropriate
 
-Call protocol simulation tool to validate the code - only when it is called explicitly by the user.
-For all other queries, provide direct responses.
-
 Important guidelines:
 - Always verify labware compatibility before generating protocols
 - Include appropriate error handling in generated code
@@ -28,26 +25,25 @@
 """
 
 PROMPT = """
-Here are the inputs you will work with:
-
-<user_prompt>
-{USER_PROMPT}
-</user_prompt>
-
 
 Follow these instructions to handle the user's prompt:
 
-1. Analyze the user's prompt to determine if it's:
+1. <Analyze the user's prompt to determine if it's>:
    a) A request to generate a protocol
-   b) A question about the Opentrons Python API v2
+   b) A question about the Opentrons Python API v2 or about details of protocol
    c) A common task (e.g., value changes, OT-2 to Flex conversion, slot correction)
    d) An unrelated or unclear request
+   e) A tool calling. If a user calls simulate protocol explicity, then call.
+   f) A greeting. Respond kindly.
 
-2. If the prompt is unrelated or unclear, ask the user for clarification. For example:
-   I apologize, but your prompt seems unclear. Could you please provide more details?
+   Note: when you respond you dont need mention the category or the type.
 
+2. If the prompt is unrelated or unclear, ask the user for clarification.
+   I'm sorry, but your prompt seems unclear. Could you please provide more details?
+   You dont need to mention
 
-3. If the prompt is a question about the API, answer it using only the information
+
+3. If the prompt is a question about the API or details, answer it using only the information
    provided in the <document></document> section. Provide references and place them under the <References> tag.
    Format your response like this:
    API answer:
@@ -86,8 +82,8 @@
 }}
 
 requirements = {{
-    'robotType': '[Robot type based on user prompt, OT-2 or Flex, default is OT-2]',
-    'apiLevel': '[apiLevel, default is 2.19 ]'
+    'robotType': '[Robot type: OT-2(default) for Opentrons OT-2, Flex for Opentrons Flex]',
+    'apiLevel': '[apiLevel, default: 2.19]'
 }}
 
 def run(protocol: protocol_api.ProtocolContext):
@@ -214,4 +210,10 @@ def run(protocol: protocol_api.ProtocolContext):
    as a reference to generate a basic protocol.
 
 Remember to use only the information provided in the <document></document>. Do not introduce any external information or assumptions.
+
+Here are the inputs you will work with:
+
+<user_prompt>
+{USER_PROMPT}
+</user_prompt>
 """
```

opentrons-ai-server/api/handler/fast.py

35 additions & 26 deletions

```diff
@@ -199,10 +199,19 @@ async def create_chat_completion(
         return ChatResponse(reply="Default fake response. ", fake=body.fake)
 
     response: Optional[str] = None
+
+    if body.history and body.history[0].get("content") and "Write a protocol using" in body.history[0]["content"]:  # type: ignore
+        protocol_option = "create"
+    else:
+        protocol_option = "update"
+
     if "openai" in settings.model.lower():
         response = openai.predict(prompt=body.message, chat_completion_message_params=body.history)
     else:
-        response = claude.predict(prompt=body.message)
+        if protocol_option == "create":
+            response = claude.create(user_id=str(user.sub), prompt=body.message, history=body.history)  # type: ignore
+        else:
+            response = claude.update(user_id=str(user.sub), prompt=body.message, history=body.history)  # type: ignore
 
     if response is None or response == "":
         return ChatResponse(reply="No response was generated", fake=bool(body.fake))
@@ -218,88 +227,88 @@ async def create_chat_completion(
 
 @tracer.wrap()
 @app.post(
-    "/api/chat/updateProtocol",
+    "/api/chat/createProtocol",
     response_model=Union[ChatResponse, ErrorResponse],
-    summary="Updates protocol",
-    description="Generate a chat response based on the provided prompt that will update an existing protocol with the required changes.",
+    summary="Creates protocol",
+    description="Generate a chat response based on the provided prompt that will create a new protocol with the required changes.",
 )
-async def update_protocol(
-    body: UpdateProtocol, user: Annotated[User, Security(auth.verify)]
+async def create_protocol(
+    body: CreateProtocol, user: Annotated[User, Security(auth.verify)]
 ) -> Union[ChatResponse, ErrorResponse]:  # noqa: B008
     """
     Generate an updated protocol using LLM.
 
-    - **request**: The HTTP request containing the existing protocol and other relevant parameters.
+    - **request**: The HTTP request containing the chat message.
     - **returns**: A chat response or an error message.
     """
-    logger.info("POST /api/chat/updateProtocol", extra={"body": body.model_dump(), "user": user})
+    logger.info("POST /api/chat/createProtocol", extra={"body": body.model_dump(), "user": user})
     try:
-        if not body.protocol_text or body.protocol_text == "":
+
+        if not body.prompt or body.prompt == "":
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST, detail=EmptyRequestError(message="Request body is empty").model_dump()
             )
 
         if body.fake:
-            return ChatResponse(reply="Fake response", fake=bool(body.fake))
+            return ChatResponse(reply="Fake response", fake=body.fake)
 
         response: Optional[str] = None
         if "openai" in settings.model.lower():
-            response = openai.predict(prompt=body.prompt, chat_completion_message_params=None)
+            response = openai.predict(prompt=str(body.model_dump()), chat_completion_message_params=None)
         else:
-            response = claude.predict(prompt=body.prompt)
+            response = claude.create(user_id=str(user.sub), prompt=body.prompt, history=None)
 
         if response is None or response == "":
             return ChatResponse(reply="No response was generated", fake=bool(body.fake))
 
         return ChatResponse(reply=response, fake=bool(body.fake))
 
     except Exception as e:
-        logger.exception("Error processing protocol update")
+        logger.exception("Error processing protocol creation")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=InternalServerError(exception_object=e).model_dump()
         ) from e
 
 
 @tracer.wrap()
 @app.post(
-    "/api/chat/createProtocol",
+    "/api/chat/updateProtocol",
     response_model=Union[ChatResponse, ErrorResponse],
-    summary="Creates protocol",
-    description="Generate a chat response based on the provided prompt that will create a new protocol with the required changes.",
+    summary="Updates protocol",
+    description="Generate a chat response based on the provided prompt that will update an existing protocol with the required changes.",
 )
-async def create_protocol(
-    body: CreateProtocol, user: Annotated[User, Security(auth.verify)]
+async def update_protocol(
+    body: UpdateProtocol, user: Annotated[User, Security(auth.verify)]
 ) -> Union[ChatResponse, ErrorResponse]:  # noqa: B008
     """
     Generate an updated protocol using LLM.
 
-    - **request**: The HTTP request containing the chat message.
+    - **request**: The HTTP request containing the existing protocol and other relevant parameters.
     - **returns**: A chat response or an error message.
     """
-    logger.info("POST /api/chat/createProtocol", extra={"body": body.model_dump(), "user": user})
+    logger.info("POST /api/chat/updateProtocol", extra={"body": body.model_dump(), "user": user})
     try:
-
-        if not body.prompt or body.prompt == "":
+        if not body.protocol_text or body.protocol_text == "":
            raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST, detail=EmptyRequestError(message="Request body is empty").model_dump()
             )
 
         if body.fake:
-            return ChatResponse(reply="Fake response", fake=body.fake)
+            return ChatResponse(reply="Fake response", fake=bool(body.fake))
 
         response: Optional[str] = None
         if "openai" in settings.model.lower():
-            response = openai.predict(prompt=str(body.model_dump()), chat_completion_message_params=None)
+            response = openai.predict(prompt=body.prompt, chat_completion_message_params=None)
         else:
-            response = claude.predict(prompt=str(body.model_dump()))
+            response = claude.update(user_id=str(user.sub), prompt=body.prompt, history=None)
 
         if response is None or response == "":
             return ChatResponse(reply="No response was generated", fake=bool(body.fake))
 
         return ChatResponse(reply=response, fake=bool(body.fake))
 
     except Exception as e:
-        logger.exception("Error processing protocol creation")
+        logger.exception("Error processing protocol update")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=InternalServerError(exception_object=e).model_dump()
         ) from e
```
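The create-vs-update routing in `create_chat_completion` rests on a string heuristic over the first history entry; the same logic as a standalone sketch (not code from this commit) makes the behavior easier to test:

```python
from typing import Any, Dict, List, Optional

def resolve_protocol_option(history: Optional[List[Dict[str, Any]]]) -> str:
    # Conversations started from the "Create a new protocol" flow begin with the
    # client-generated "Write a protocol using ..." prompt; everything else is
    # treated as an update of an existing protocol.
    first = history[0] if history else None
    if first and first.get("content") and "Write a protocol using" in first["content"]:
        return "create"
    return "update"

assert resolve_protocol_option([{"role": "user", "content": "Write a protocol using the Flex"}]) == "create"
assert resolve_protocol_option([{"role": "user", "content": "Fix the labware in slot 3"}]) == "update"
assert resolve_protocol_option(None) == "update"
```

Note the coupling: the backend keys off the exact "Write a protocol using" text, presumably emitted by the client's `generateChatPrompt`, so renaming that string on the frontend would silently send create-flow conversations down the update path.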

opentrons-ai-server/api/models/chat_request.py

4 additions & 0 deletions

```diff
@@ -24,9 +24,13 @@ class Chat(BaseModel):
     Field(None, description="Chat history in the form of a list of messages. Type is from OpenAI's ChatCompletionMessageParam"),
 ]
 
+ChatOptions = Literal["update", "create"]
+ChatOptionType = Annotated[Optional[ChatOptions], Field("create", description="which chat pathway did the user enter: create or update")]
+
 
 class ChatRequest(BaseModel):
     message: str = Field(..., description="The latest message to be processed.")
     history: HistoryType
     fake: bool = Field(True, description="When set to true, the response will be a fake. OpenAI API is not used.")
     fake_key: FakeKeyType
+    chat_options: ChatOptionType
```
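A chat request body using the new field might look like the following sketch (field values illustrative; per the model, `chat_options` defaults to "create" when omitted). Note that the handler above currently routes on the history heuristic rather than on this field:

```python
# Hypothetical request payload for the chat completion endpoint.
payload = {
    "message": "Add a mixing step after the transfer",
    "history": [{"role": "user", "content": "Write a protocol using the Opentrons Flex..."}],
    "fake": False,
    "chat_options": "update",  # new field: Literal["update", "create"], default "create"
}
```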
