Commit 3b1c3d5

feat: Remove Inference from Env Config
1 parent 49f12d0 commit 3b1c3d5

9 files changed (+40 -38 lines)
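In short: the RAG inference parameters (temperature, topP, maxTokens) are no longer read from environment variables wired through cdk.json; they now come from the inference configuration attached to the Bedrock prompt template, with hardcoded fallbacks when the template carries none. A minimal before/after sketch, with load_prompt stubbed out (the stub and its values are illustrative, not part of this commit):

import os

def load_prompt_stub() -> dict:
    """Stand-in for app.services.prompt_loader.load_prompt after this commit."""
    return {"prompt_text": "...", "inference_config": {"temperature": 0.2}}

# Before: each parameter was an environment variable wired through cdk.json.
old_temperature = os.environ.get("RAG_TEMPERATURE", "1")

# After: the parameters travel with the prompt from Bedrock Prompt Management.
inference_config = load_prompt_stub().get("inference_config") or {}
new_temperature = inference_config.get("temperature", 1)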

.github/scripts/fix_cdk_json.sh

Lines changed: 0 additions & 3 deletions
@@ -61,6 +61,3 @@ fix_string_key slackBotToken "${SLACK_BOT_TOKEN}"
 fix_string_key slackSigningSecret "${SLACK_SIGNING_SECRET}"
 fix_string_key cfnDriftDetectionGroup "${CFN_DRIFT_DETECTION_GROUP}"
 fix_boolean_number_key isPullRequest "${IS_PULL_REQUEST}"
-fix_boolean_number_key ragTemperature "${RAG_TEMPERATURE}"
-fix_boolean_number_key ragMaxTokens "${RAG_MAX_TOKENS}"
-fix_boolean_number_key ragTopP "${RAG_TOP_P}"
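The removed calls went through fix_boolean_number_key rather than fix_string_key because temperature, topP, and maxTokens must land in cdk.json unquoted. A hypothetical Python equivalent of that shell helper, for illustration only (the real implementation lives in this script and may differ, including where the keys sit in cdk.json):

import json

def fix_boolean_number_key(cdk_json_path: str, key: str, raw_value: str) -> None:
    # Load cdk.json, set the key to an unquoted boolean/number, write it back.
    with open(cdk_json_path) as fh:
        cfg = json.load(fh)
    if raw_value.lower() in ("true", "false"):
        value = raw_value.lower() == "true"  # booleans stay unquoted
    else:
        value = json.loads(raw_value)        # "0.5" -> 0.5, "512" -> 512
    cfg["context"][key] = value              # assumes keys live under "context"
    with open(cdk_json_path, "w") as fh:
        json.dump(cfg, fh, indent=2)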

packages/slackBotFunction/app/core/config.py

Lines changed: 0 additions & 6 deletions
@@ -80,9 +80,6 @@ def get_retrieve_generate_config() -> Tuple[str, str, str, str, str, str, str, s
     GUARD_VERSION = os.environ["GUARD_RAIL_VERSION"]
     RAG_RESPONSE_PROMPT_NAME = os.environ["RAG_RESPONSE_PROMPT_NAME"]
     RAG_RESPONSE_PROMPT_VERSION = os.environ["RAG_RESPONSE_PROMPT_VERSION"]
-    RAG_TEMPERATURE = os.environ["RAG_TEMPERATURE"]
-    RAG_MAX_TOKENS = os.environ["RAG_MAX_TOKENS"]
-    RAG_TOP_P = os.environ["RAG_TOP_P"]

     logger.info(
         "Guardrail configuration loaded", extra={"guardrail_id": GUARD_RAIL_ID, "guardrail_version": GUARD_VERSION}
@@ -95,9 +92,6 @@ def get_retrieve_generate_config() -> Tuple[str, str, str, str, str, str, str, s
         GUARD_VERSION,
         RAG_RESPONSE_PROMPT_NAME,
         RAG_RESPONSE_PROMPT_VERSION,
-        RAG_TEMPERATURE,
-        RAG_MAX_TOKENS,
-        RAG_TOP_P,
     )

packages/slackBotFunction/app/services/bedrock.py

Lines changed: 17 additions & 12 deletions
@@ -1,5 +1,4 @@
 import json
-import os
 from typing import Any
 import boto3
 from mypy_boto3_bedrock_agent_runtime import AgentsforBedrockRuntimeClient
@@ -29,12 +28,18 @@ def query_bedrock(user_query: str, session_id: str = None) -> RetrieveAndGenerat
         GUARD_VERSION,
         RAG_RESPONSE_PROMPT_NAME,
         RAG_RESPONSE_PROMPT_VERSION,
-        RAG_TEMPERATURE,
-        RAG_MAX_TOKENS,
-        RAG_TOP_P,
     ) = get_retrieve_generate_config()

     prompt_template = load_prompt(RAG_RESPONSE_PROMPT_NAME, RAG_RESPONSE_PROMPT_VERSION)
+    inference_config = prompt_template.get("inference_config")
+
+    if not inference_config:
+        default_values = {"temperature": 0, "maxTokens": 512, "topP": 1}
+        inference_config = default_values
+        logger.warning(
+            "No inference configuration found in prompt template; using default values",
+            extra={"prompt_name": RAG_RESPONSE_PROMPT_NAME, "default_inference_config": default_values},
+        )

     client: AgentsforBedrockRuntimeClient = boto3.client(
         service_name="bedrock-agent-runtime",
@@ -54,9 +59,9 @@ def query_bedrock(user_query: str, session_id: str = None) -> RetrieveAndGenerat
             },
             "inferenceConfig": {
                 "textInferenceConfig": {
-                    "temperature": RAG_TEMPERATURE,
-                    "topP": RAG_TOP_P,
-                    "maxTokens": RAG_MAX_TOKENS,
+                    "temperature": inference_config.get("temperature", 1),
+                    "topP": inference_config.get("topP", 1),
+                    "maxTokens": inference_config.get("maxTokens", 512),
                     "stopSequences": [
                         "Human:",
                     ],
@@ -70,7 +75,7 @@ def query_bedrock(user_query: str, session_id: str = None) -> RetrieveAndGenerat
     if prompt_template:
         request_params["retrieveAndGenerateConfiguration"]["knowledgeBaseConfiguration"]["generationConfiguration"][
             "promptTemplate"
-        ] = {"textPromptTemplate": prompt_template}
+        ] = {"textPromptTemplate": prompt_template.get("prompt_text")}
         logger.info(
             "Using prompt template for RAG response generation", extra={"prompt_name": RAG_RESPONSE_PROMPT_NAME}
         )
@@ -90,16 +95,16 @@ def query_bedrock(user_query: str, session_id: str = None) -> RetrieveAndGenerat
     return response


-def invoke_model(prompt: str, model_id: str, client: BedrockRuntimeClient) -> dict[str, Any]:
+def invoke_model(prompt: str, model_id: str, client: BedrockRuntimeClient, inference_config: dict) -> dict[str, Any]:
     response = client.invoke_model(
         modelId=model_id,
         body=json.dumps(
             {
                 "anthropic_version": "bedrock-2023-05-31",
-                "temperature": os.environ.get("RAG_TEMPERATURE", "1"),
-                "top_p": os.environ.get("RAG_TOP_P", "1"),
+                "temperature": inference_config.get("temperature", "1"),
+                "top_p": inference_config.get("topP", "1"),
                 "top_k": 50,
-                "max_tokens": os.environ.get("RAG_MAX_TOKENS", "512"),
+                "max_tokens": inference_config.get("maxTokens", "512"),
                 "messages": [{"role": "user", "content": prompt}],
             }
         ),
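The fallback added to query_bedrock reduces to this standalone sketch; resolve_inference_config is an illustrative name rather than a function in this commit, and the flat dict shape follows the tests further down:

def resolve_inference_config(prompt_template: dict) -> dict:
    # Use the prompt's inference config; fall back to defaults if it is absent.
    return prompt_template.get("inference_config") or {"temperature": 0, "maxTokens": 512, "topP": 1}

# An empty config triggers the defaults; a populated one is used as-is.
assert resolve_inference_config({"inference_config": {}})["maxTokens"] == 512
assert resolve_inference_config({"inference_config": {"temperature": 0.2}}) == {"temperature": 0.2}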

packages/slackBotFunction/app/services/prompt_loader.py

Lines changed: 4 additions & 2 deletions
@@ -71,7 +71,7 @@ def parse_system_message(chat_cfg: dict) -> str:
     return "\n\n".join(parts)


-def load_prompt(prompt_name: str, prompt_version: str = None) -> str:
+def load_prompt(prompt_name: str, prompt_version: str = None) -> dict:
     """
     Load a prompt template from Amazon Bedrock Prompt Management.

@@ -103,16 +103,18 @@ def load_prompt(prompt_name: str, prompt_version: str = None) -> str:
         template_config = response["variants"][0]["templateConfiguration"]
         prompt_text = _render_prompt(template_config)
         actual_version = response.get("version", "DRAFT")
+        inference_config = response["variants"][0]["inferenceConfiguration"]

         logger.info(
             f"Successfully loaded prompt '{prompt_name}' version {actual_version}",
             extra={
                 "prompt_name": prompt_name,
                 "prompt_id": prompt_id,
                 "version_used": actual_version,
+                "inference_config": inference_config,
             },
         )
-        return prompt_text
+        return {"prompt_text": prompt_text, "inference_config": inference_config}

     except ClientError as e:
         error_code = e.response.get("Error", {}).get("Code", "Unknown")
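With load_prompt now returning a dict instead of a bare string, callers unpack it roughly like this (a sketch mirroring the shape above; the literal values are placeholders):

loaded = {"prompt_text": "Answer concisely. {{user_query}}", "inference_config": {"temperature": 0}}
prompt_text = loaded.get("prompt_text")            # rendered template string, as before
inference_config = loaded.get("inference_config")  # new: forwarded to the model call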

packages/slackBotFunction/app/services/query_reformulator.py

Lines changed: 4 additions & 2 deletions
@@ -38,8 +38,10 @@ def reformulate_query(user_query: str) -> str:
     )

     # Format the prompt with the user query (using double braces from Bedrock template)
-    prompt = prompt_template.replace("{{user_query}}", user_query)
-    result = invoke_model(prompt=prompt, model_id=model_id, client=client)
+    prompt = prompt_template.get("prompt_text").replace("{{user_query}}", user_query)
+    result = invoke_model(
+        prompt=prompt, model_id=model_id, client=client, inference_config=prompt_template.get("inference_config")
+    )

     reformulated_query = result["content"][0]["text"].strip()

packages/slackBotFunction/tests/conftest.py

Lines changed: 0 additions & 3 deletions
@@ -26,9 +26,6 @@ def mock_env():
         "QUERY_REFORMULATION_PROMPT_VERSION": "DRAFT",
         "RAG_RESPONSE_PROMPT_NAME": "test-rag-prompt",
         "RAG_RESPONSE_PROMPT_VERSION": "DRAFT",
-        "RAG_TEMPERATURE": "0.5",
-        "RAG_MAX_TOKENS": "1024",
-        "RAG_TOP_P": "0.9",
     }
     env_vars["AWS_DEFAULT_REGION"] = env_vars["AWS_REGION"]
     with patch.dict(os.environ, env_vars, clear=False):

packages/slackBotFunction/tests/test_bedrock_integration.py

Lines changed: 8 additions & 4 deletions
@@ -84,7 +84,7 @@ def test_query_bedrock_check_prompt(mock_boto_client: Mock, mock_load_prompt: Mo
     mock_client = Mock()
     mock_boto_client.return_value = mock_client
     mock_client.retrieve_and_generate.return_value = {"output": {"text": "response"}}
-    mock_load_prompt.return_value = "Test prompt template"
+    mock_load_prompt.return_value = {"prompt_text": "Test prompt template", "inference_config": {}}

     # delete and import module to test
     if "app.services.bedrock" in sys.modules:
@@ -112,6 +112,10 @@ def test_query_bedrock_check_config(mock_boto_client: Mock, mock_load_prompt: Mo
     mock_client = Mock()
     mock_boto_client.return_value = mock_client
     mock_client.retrieve_and_generate.return_value = {"output": {"text": "response"}}
+    mock_load_prompt.return_value = {
+        "prompt_text": "Test prompt template",
+        "inference_config": {"temperature": "0", "maxTokens": "512", "topP": "1"},
+    }

     # delete and import module to test
     if "app.services.bedrock" in sys.modules:
@@ -127,6 +131,6 @@ def test_query_bedrock_check_config(mock_boto_client: Mock, mock_load_pro
         "generationConfiguration"
     ]["inferenceConfig"]["textInferenceConfig"]

-    assert prompt_config["temperature"] == "0.5"
-    assert prompt_config["maxTokens"] == "1024"
-    assert prompt_config["topP"] == "0.9"
+    assert prompt_config["temperature"] == "0"
+    assert prompt_config["maxTokens"] == "512"
+    assert prompt_config["topP"] == "1"

packages/slackBotFunction/tests/test_prompt_loader.py

Lines changed: 4 additions & 4 deletions
@@ -20,7 +20,7 @@ def test_load_prompt_success_draft(mock_boto_client: Mock, mock_env: Mock):

     # Mock get_prompt for DRAFT version
     mock_client.get_prompt.return_value = {
-        "variants": [{"templateConfiguration": {"text": {"text": "Test prompt"}}}],
+        "variants": [{"templateConfiguration": {"text": {"text": "Test prompt"}}, "inferenceConfiguration": {}}],
         "version": "DRAFT",
     }

@@ -33,7 +33,7 @@ def test_load_prompt_success_draft(mock_boto_client: Mock, mock_env: Mock):
     result = load_prompt("test-prompt")

     # assertions
-    assert result == "Test prompt"
+    assert result.get("prompt_text") == "Test prompt"
     mock_client.get_prompt.assert_called_once_with(promptIdentifier="ABC1234567")


@@ -46,7 +46,7 @@ def test_load_prompt_success_versioned(mock_boto_client: Mock, mock_env: Mock):
     mock_client.list_prompts.return_value = {"promptSummaries": [{"name": "test-prompt", "id": "ABC1234567"}]}

     mock_client.get_prompt.return_value = {
-        "variants": [{"templateConfiguration": {"text": {"text": "Versioned prompt"}}}],
+        "variants": [{"templateConfiguration": {"text": {"text": "Versioned prompt"}}, "inferenceConfiguration": {}}],
         "version": "1",
     }

@@ -59,7 +59,7 @@ def test_load_prompt_success_versioned(mock_boto_client: Mock, mock_env: Mock):
     result = load_prompt("test-prompt", "1")

     # assertions
-    assert result == "Versioned prompt"
+    assert result.get("prompt_text") == "Versioned prompt"
     mock_client.get_prompt.assert_called_once_with(promptIdentifier="ABC1234567", promptVersion="1")

packages/slackBotFunction/tests/test_query_reformulator.py

Lines changed: 3 additions & 2 deletions
@@ -14,7 +14,7 @@ def mock_logger():
 def test_reformulate_query_returns_string(mock_invoke_model: Mock, mock_load_prompt: Mock, mock_env: Mock):
     """Test that reformulate_query returns a string without crashing"""
     # set up mocks
-    mock_load_prompt.return_value = "Test reformat. {{user_query}}"
+    mock_load_prompt.return_value = {"prompt_text": "Test reformat. {{user_query}}", "inference_config": {}}
     mock_invoke_model.return_value = {"content": [{"text": "foo"}]}

     # delete and import module to test
@@ -24,6 +24,7 @@ def test_reformulate_query_returns_string(mock_invoke_model: Mock, mock_load_pro

     # perform operation
     result = reformulate_query("How do I use EPS?")
+    result = result

     # assertions
     # Function should return a string (either reformulated or fallback to original)
@@ -32,7 +33,7 @@ def test_reformulate_query_returns_string(mock_invoke_model: Mock, mock_load_pro
     assert result == "foo"
     mock_load_prompt.assert_called_once_with("test-prompt", "DRAFT")
     mock_invoke_model.assert_called_once_with(
-        prompt="Test reformat. How do I use EPS?", model_id="test-model", client=ANY
+        prompt="Test reformat. How do I use EPS?", model_id="test-model", client=ANY, inference_config={}
     )