
Commit d7df9e8

System instruction override
1 parent abef33c commit d7df9e8

7 files changed: +141 additions, −36 deletions

src/neo4j_graphrag/llm/anthropic_llm.py

Lines changed: 22 additions & 4 deletions
@@ -84,22 +84,31 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
         Args:
             input (str): The text to send to the LLM.
             message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
             messages = self.get_messages(input, message_history)
+            system_message = (
+                system_instruction
+                if system_instruction is not None
+                else self.system_instruction
+            )
             response = self.client.messages.create(
                 model=self.model_name,
-                system=self.system_instruction,
+                system=system_message,
                 messages=messages,
                 **self.model_params,
             )
@@ -108,22 +117,31 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
 
         Args:
             input (str): The text to send to the LLM.
             message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
             messages = self.get_messages(input, message_history)
+            system_message = (
+                system_instruction
+                if system_instruction is not None
+                else self.system_instruction
+            )
             response = await self.async_client.messages.create(
                 model=self.model_name,
-                system=self.system_instruction,
+                system=system_message,
                 messages=messages,
                 **self.model_params,
             )
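
Usage sketch for the new parameter (assumptions: AnthropicLLM is exported from neo4j_graphrag.llm, its constructor accepts model_name and system_instruction keywords, ANTHROPIC_API_KEY is set in the environment, and the model name is illustrative):

from neo4j_graphrag.llm import AnthropicLLM

llm = AnthropicLLM(
    model_name="claude-3-5-sonnet-20240620",  # illustrative model name
    system_instruction="You are a terse assistant.",  # default for every call
)

# Uses the constructor-level system instruction.
res_default = llm.invoke("Summarize the plot of Hamlet.")

# Overrides the system instruction for this invocation only;
# the instance default is untouched for later calls.
res_override = llm.invoke(
    "Summarize the plot of Hamlet.",
    system_instruction="Answer in exactly one sentence.",
)
print(res_override.content)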

src/neo4j_graphrag/llm/base.py

Lines changed: 5 additions & 2 deletions
@@ -64,14 +64,17 @@ def invoke(
 
     @abstractmethod
     async def ainvoke(
-        self, input: str, message_history: Optional[list[dict[str, str]]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the LLM and retrieves a response.
 
         Args:
             input (str): Text sent to the LLM.
             message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
-
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
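
Because the abstract signature changed, custom subclasses of LLMInterface must now accept the extra keyword. A minimal toy subclass, sketched under the assumption that LLMInterface and LLMResponse are importable from neo4j_graphrag.llm and that the base constructor stores system_instruction on the instance:

from typing import Optional

from neo4j_graphrag.llm import LLMInterface, LLMResponse


class EchoLLM(LLMInterface):
    """Toy provider showing the signature subclasses must now implement."""

    def invoke(
        self,
        input: str,
        message_history: Optional[list] = None,
        system_instruction: Optional[str] = None,
    ) -> LLMResponse:
        # Mirror the fallback pattern used by the real providers.
        system = (
            system_instruction
            if system_instruction is not None
            else self.system_instruction
        )
        return LLMResponse(content=f"[{system}] {input}")

    async def ainvoke(
        self,
        input: str,
        message_history: Optional[list] = None,
        system_instruction: Optional[str] = None,
    ) -> LLMResponse:
        return self.invoke(input, message_history, system_instruction)


llm = EchoLLM(model_name="echo", system_instruction="default")
print(llm.invoke("hi", system_instruction="override").content)  # [override] hi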

src/neo4j_graphrag/llm/cohere_llm.py

Lines changed: 23 additions & 7 deletions
@@ -75,11 +75,19 @@ def __init__(
         self.async_client = cohere.AsyncClientV2(**kwargs)
 
     def get_messages(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> ChatMessages:
         messages = []
-        if self.system_instruction:
-            messages.append(SystemMessage(content=self.system_instruction).model_dump())
+        system_message = (
+            system_instruction
+            if system_instruction is not None
+            else self.system_instruction
+        )
+        if system_message:
+            messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
                 MessageList(messages=message_history)
@@ -90,19 +98,23 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
    ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
         Args:
             input (str): The text to send to the LLM.
             message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            messages = self.get_messages(input, message_history)
+            messages = self.get_messages(input, message_history, system_instruction)
             res = self.client.chat(
                 messages=messages,
                 model=self.model_name,
@@ -114,19 +126,23 @@ def invoke(
         )
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
    ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
 
         Args:
             input (str): The text to send to the LLM.
             message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            messages = self.get_messages(input, message_history)
+            messages = self.get_messages(input, message_history, system_instruction)
             res = self.async_client.chat(
                 messages=messages,
                 model=self.model_name,
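
Here the override is threaded through get_messages rather than a dedicated system= argument, so it lands as the first chat message. A sketch of the resulting message list (assumptions: CohereLLM is exported from neo4j_graphrag.llm, CO_API_KEY is set so the client can be constructed, and SystemMessage.model_dump() yields a role/content dict):

from neo4j_graphrag.llm import CohereLLM

llm = CohereLLM(model_name="command-r", system_instruction="Default instruction.")

messages = llm.get_messages("Hello!", system_instruction="Override for this call.")
# Expected first entry, per the new fallback logic:
# {'role': 'system', 'content': 'Override for this call.'}
print(messages[0])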

src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 23 additions & 7 deletions
@@ -66,11 +66,19 @@ def __init__(
         self.client = Mistral(api_key=api_key, **kwargs)
 
     def get_messages(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> list[Messages]:
         messages = []
-        if self.system_instruction:
-            messages.append(SystemMessage(content=self.system_instruction).model_dump())
+        system_message = (
+            system_instruction
+            if system_instruction is not None
+            else self.system_instruction
+        )
+        if system_message:
+            messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
                 MessageList(messages=message_history)
@@ -81,14 +89,18 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
         and returns the response's content.
 
         Args:
             input (str): Text sent to the LLM.
             message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from MistralAI.
@@ -97,7 +109,7 @@ def invoke(
             LLMGenerationError: If anything goes wrong.
         """
         try:
-            messages = self.get_messages(input, message_history)
+            messages = self.get_messages(input, message_history, system_instruction)
             response = self.client.chat.complete(
                 model=self.model_name,
                 messages=messages,
@@ -113,14 +125,18 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat
         completion model and returns the response's content.
 
         Args:
             input (str): Text sent to the LLM.
             message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from MistralAI.
@@ -129,7 +145,7 @@ async def ainvoke(
             LLMGenerationError: If anything goes wrong.
         """
         try:
-            messages = self.get_messages(input, message_history)
+            messages = self.get_messages(input, message_history, system_instruction)
             response = await self.client.chat.complete_async(
                 model=self.model_name,
                 messages=messages,
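
One subtlety of the `is not None` check, visible in both get_messages implementations above: an empty string counts as a deliberate override, and since the chosen value is then tested for truthiness, it suppresses the system message entirely. A standalone sketch of the selection logic as committed (pure Python, no Mistral client needed):

from typing import Optional


def select_system_message(
    instance_instruction: Optional[str],
    call_instruction: Optional[str],
) -> Optional[str]:
    # Mirrors the conditional expression added in this commit.
    return (
        call_instruction
        if call_instruction is not None
        else instance_instruction
    )


assert select_system_message("default", None) == "default"         # falls back
assert select_system_message("default", "override") == "override"  # per-call wins
assert select_system_message("default", "") == ""                  # "" wins, then fails `if system_message:`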

src/neo4j_graphrag/llm/ollama_llm.py

Lines changed: 45 additions & 7 deletions
@@ -50,11 +50,19 @@ def __init__(
         )
 
     def get_messages(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> Sequence[Message]:
         messages = []
-        if self.system_instruction:
-            messages.append(SystemMessage(content=self.system_instruction).model_dump())
+        system_message = (
+            system_instruction
+            if system_instruction is not None
+            else self.system_instruction
+        )
+        if system_message:
+            messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
                 MessageList(messages=message_history)
@@ -65,12 +73,25 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
+        """Sends text to the LLM and returns a response.
+
+        Args:
+            input (str): The text to send to the LLM.
+            message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            LLMResponse: The response from the LLM.
+        """
         try:
             response = self.client.chat(
                 model=self.model_name,
-                messages=self.get_messages(input, message_history),
+                messages=self.get_messages(input, message_history, system_instruction),
                 options=self.model_params,
             )
             content = response.message.content or ""
@@ -79,12 +100,29 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
+        """Asynchronously sends a text input to the Ollama chat
+        completion model and returns the response's content.
+
+        Args:
+            input (str): Text sent to the LLM.
+            message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            LLMResponse: The response from the LLM.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
         try:
             response = await self.async_client.chat(
                 model=self.model_name,
-                messages=self.get_messages(input, message_history),
+                messages=self.get_messages(input, message_history, system_instruction),
                 options=self.model_params,
             )
             content = response.message.content or ""
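
A usage sketch against a local model (assumptions: OllamaLLM is exported from neo4j_graphrag.llm, an Ollama server is running locally with the model pulled, and the model tag is illustrative):

from neo4j_graphrag.llm import OllamaLLM

llm = OllamaLLM(model_name="llama3")  # requires a running Ollama server

res = llm.invoke(
    "What is a knowledge graph?",
    system_instruction="Answer as a Neo4j expert, in two sentences.",
)
print(res.content)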

src/neo4j_graphrag/llm/openai_llm.py

Lines changed: 6 additions & 2 deletions
@@ -115,14 +115,18 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self,
+        input: str,
+        message_history: Optional[list[BaseMessage]] = None,
+        system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
         completion model and returns the response's content.
 
         Args:
             input (str): Text sent to the LLM.
             message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
 
         Returns:
             LLMResponse: The response from OpenAI.
@@ -132,7 +136,7 @@ async def ainvoke(
         """
         try:
             response = await self.async_client.chat.completions.create(
-                messages=self.get_messages(input, message_history),
+                messages=self.get_messages(input, message_history, system_instruction),
                 model=self.model_name,
                 **self.model_params,
             )
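
And the async path this hunk wires up, as a sketch (assumptions: OpenAILLM is exported from neo4j_graphrag.llm, OPENAI_API_KEY is set, and the model name is illustrative):

import asyncio

from neo4j_graphrag.llm import OpenAILLM


async def main() -> None:
    llm = OpenAILLM(model_name="gpt-4o", system_instruction="You are concise.")
    res = await llm.ainvoke(
        "List three graph algorithms.",
        system_instruction="Answer with a comma-separated list only.",
    )
    print(res.content)


asyncio.run(main())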
