Skip to content

Commit a8c52c5

Browse files
authored
[2.x] Fix Amazon Nova support (use StrOutputParser) (#1203)
* use StrOutputParser in default chat
* encourage using StrOutputParser in docs
* pre-commit
* use StrOutputParser in /fix
1 parent 92262ba commit a8c52c5

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

docs/source/developers/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def create_llm_chain(
492492
prompt_template = FIX_PROMPT_TEMPLATE
493493
self.prompt_template = prompt_template
494494

495-
runnable = prompt_template | llm # type:ignore
495+
runnable = prompt_template | llm | StrOutputParser() # type:ignore
496496
self.llm_chain = runnable
497497
```
498498

packages/jupyter-ai/jupyter_ai/chat_handlers/default.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from jupyter_ai.models import HumanChatMessage
55
from jupyter_ai_magics.providers import BaseProvider
6+
from langchain_core.output_parsers import StrOutputParser
67
from langchain_core.runnables import ConfigurableFieldSpec
78
from langchain_core.runnables.history import RunnableWithMessageHistory
89

@@ -37,7 +38,7 @@ def create_llm_chain(
3738
self.llm = llm
3839
self.prompt_template = prompt_template
3940

40-
runnable = prompt_template | llm # type:ignore
41+
runnable = prompt_template | llm | StrOutputParser() # type:ignore
4142
if not llm.manages_history:
4243
runnable = RunnableWithMessageHistory(
4344
runnable=runnable, # type:ignore[arg-type]

packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage
44
from jupyter_ai_magics.providers import BaseProvider
55
from langchain.prompts import PromptTemplate
6+
from langchain_core.output_parsers import StrOutputParser
67

78
from .base import BaseChatHandler, SlashCommandRoutingType
89

@@ -76,7 +77,7 @@ def create_llm_chain(
7677
self.llm = llm
7778
prompt_template = FIX_PROMPT_TEMPLATE
7879

79-
runnable = prompt_template | llm # type:ignore
80+
runnable = prompt_template | llm | StrOutputParser() # type:ignore
8081
self.llm_chain = runnable
8182

8283
async def process_message(self, message: HumanChatMessage):

0 commit comments

Comments (0)