Skip to content

Commit 7e594f5

Browse files
committed
Reduce temperature
1 parent 2a26b69 commit 7e594f5

File tree

3 files changed

+40
-13
lines changed

3 files changed

+40
-13
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination
3+
from autogen_agentchat.conditions import (
4+
TextMentionTermination,
5+
MaxMessageTermination,
6+
SourceMatchTermination,
7+
)
48
from autogen_agentchat.teams import SelectorGroupChat
59
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
610
from autogen_text_2_sql.creators.llm_agent_creator import LLMAgentCreator
@@ -89,7 +93,11 @@ def agents(self):
8993
@property
9094
def termination_condition(self):
9195
"""Define the termination condition for the chat."""
92-
termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(20)
96+
termination = (
97+
TextMentionTermination("TERMINATE")
98+
| MaxMessageTermination(20)
99+
| SourceMatchTermination(["answer_agent"])
100+
)
93101
return termination
94102

95103
@staticmethod

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_model_creator.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from autogen_ext.models import AzureOpenAIChatCompletionClient
4-
from text_2_sql_core.connectors.factory import ConnectorFactory
4+
from text_2_sql_core.utils.environment import IdentityType, get_identity_type
55

6+
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
67
import os
78
import dotenv
89

@@ -27,12 +28,32 @@ def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
2728
else:
2829
raise ValueError(f"Model {model_name} not found")
2930

31+
@classmethod
32+
def get_authentication_properties(cls) -> dict:
33+
if get_identity_type() == IdentityType.SYSTEM_ASSIGNED:
34+
# Create the token provider
35+
api_key = None
36+
token_provider = get_bearer_token_provider(
37+
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
38+
)
39+
elif get_identity_type() == IdentityType.USER_ASSIGNED:
40+
# Create the token provider
41+
api_key = None
42+
token_provider = get_bearer_token_provider(
43+
DefaultAzureCredential(
44+
managed_identity_client_id=os.environ["ClientId"]
45+
),
46+
"https://cognitiveservices.azure.com/.default",
47+
)
48+
else:
49+
token_provider = None
50+
api_key = os.environ["OpenAI__ApiKey"]
51+
52+
return token_provider, api_key
53+
3054
@classmethod
3155
def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
32-
(
33-
token_provider,
34-
api_key,
35-
) = ConnectorFactory.get_open_ai_connector().get_authentication_properties()
56+
token_provider, api_key = cls.get_authentication_properties()
3657
return AzureOpenAIChatCompletionClient(
3758
azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
3859
model=os.environ["OpenAI__MiniCompletionDeployment"],
@@ -45,14 +66,12 @@ def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
4566
"function_calling": True,
4667
"json_output": True,
4768
},
69+
temperature=0,
4870
)
4971

5072
@classmethod
5173
def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
52-
(
53-
token_provider,
54-
api_key,
55-
) = ConnectorFactory.get_open_ai_connector().get_authentication_properties()
74+
token_provider, api_key = cls.get_authentication_properties()
5675
return AzureOpenAIChatCompletionClient(
5776
azure_deployment=os.environ["OpenAI__CompletionDeployment"],
5877
model=os.environ["OpenAI__CompletionDeployment"],
@@ -65,4 +84,5 @@ def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
6584
"function_calling": True,
6685
"json_output": True,
6786
},
87+
temperature=0,
6888
)

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,4 @@ system_message:
1616
}
1717
1818
Title is the entity name of the schema, chunk is the result of the SQL query and reference is the SQL query used to generate the answer.
19-
20-
End your answer with 'TERMINATE'"
19+
"

0 commit comments

Comments
 (0)