Skip to content

Commit e39b31b

Browse files
feat: Set Inference Config in CDK
1 parent 65a163a commit e39b31b

File tree

14 files changed

+150
-112
lines changed

14 files changed

+150
-112
lines changed

.github/scripts/fix_cdk_json.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,6 @@ fix_string_key slackBotToken "${SLACK_BOT_TOKEN}"
6161
fix_string_key slackSigningSecret "${SLACK_SIGNING_SECRET}"
6262
fix_string_key cfnDriftDetectionGroup "${CFN_DRIFT_DETECTION_GROUP}"
6363
fix_boolean_number_key isPullRequest "${IS_PULL_REQUEST}"
64+
fix_boolean_number_key ragTemperature "${RAG_TEMPERATURE}"
65+
fix_boolean_number_key ragMaxTokens "${RAG_MAX_TOKENS}"
66+
fix_boolean_number_key ragTopP "${RAG_TOP_P}"

packages/cdk/prompts/BedrockPromptsCollection.ts

Lines changed: 0 additions & 81 deletions
This file was deleted.

packages/cdk/resources/BedrockPromptResources.ts

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ import {
44
Prompt,
55
PromptVariant
66
} from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/bedrock"
7-
import {BedrockPromptCollection} from "../prompts/BedrockPromptsCollection"
7+
import {BedrockPromptSettings} from "./BedrockPromptSettings"
88

99
export interface BedrockPromptResourcesProps {
1010
readonly stackName: string
11-
readonly collection: BedrockPromptCollection
11+
readonly settings: BedrockPromptSettings
1212
}
1313

1414
export class BedrockPromptResources extends Construct {
@@ -25,7 +25,7 @@ export class BedrockPromptResources extends Construct {
2525
variantName: "default",
2626
model: claudeHaikuModel,
2727
promptVariables: ["topic"],
28-
promptText: props.collection.reformulationPrompt.text
28+
promptText: props.settings.reformulationPrompt.text
2929
})
3030

3131
const queryReformulationPrompt = new Prompt(this, "QueryReformulationPrompt", {
@@ -39,19 +39,12 @@ export class BedrockPromptResources extends Construct {
3939
variantName: "default",
4040
model: claudeSonnetModel,
4141
promptVariables: ["query", "search_results"],
42-
system: props.collection.systemPrompt.text,
43-
messages: [props.collection.userPrompt]
42+
system: props.settings.systemPrompt.text,
43+
messages: [props.settings.userPrompt]
4444
})
4545

4646
ragResponsePromptVariant["inferenceConfiguration"] = {
47-
"text": {
48-
"temperature": 0,
49-
"topP": 1,
50-
"maxTokens": 512,
51-
"stopSequences": [
52-
"Human:"
53-
]
54-
}
47+
"text": props.settings.inferenceConfig
5548
}
5649

5750
const ragPrompt = new Prompt(this, "ragResponsePrompt", {
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import * as fs from "fs"
2+
import {ChatMessage} from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/bedrock"
3+
import {Construct} from "constructs"
4+
5+
export type BedrockPromptSettingsType = "system" | "user" | "reformulation"
6+
7+
export interface BedrockPromptInferenceConfig {
8+
temperature: number,
9+
topP: number,
10+
maxTokens: number,
11+
stopSequences: Array<string>
12+
}
13+
14+
/** BedrockPromptSettings is responsible for loading and providing
15+
* the system, user, and reformulation prompts along with their
16+
* inference configurations.
17+
*/
18+
export class BedrockPromptSettings extends Construct {
19+
public readonly systemPrompt: ChatMessage
20+
public readonly userPrompt: ChatMessage
21+
public readonly reformulationPrompt: ChatMessage
22+
public readonly inferenceConfig: BedrockPromptInferenceConfig
23+
24+
/**
25+
* @param scope The Construct scope
26+
* @param id The Construct ID
27+
* @param props BedrockPromptSettingsProps containing optional version info for each prompt
28+
*/
29+
constructor(scope: Construct, id: string) {
30+
super(scope, id)
31+
32+
const systemPromptData = this.getTypedPrompt("system")
33+
this.systemPrompt = ChatMessage.assistant(systemPromptData.text)
34+
35+
const userPromptData = this.getTypedPrompt("user")
36+
this.userPrompt = ChatMessage.user(userPromptData.text)
37+
38+
const reformulationPrompt = this.getTypedPrompt("reformulation")
39+
this.reformulationPrompt = ChatMessage.user(reformulationPrompt.text)
40+
41+
const temperature = this.node.tryGetContext("ragTemperature")
42+
const maxTokens = this.node.tryGetContext("ragMaxTokens")
43+
const topP = this.node.tryGetContext("ragTopP")
44+
45+
this.inferenceConfig = {
46+
temperature: parseInt(temperature, 10),
47+
topP: parseInt(topP, 10),
48+
maxTokens: parseInt(maxTokens, 10),
49+
stopSequences: [
50+
"Human:"
51+
]
52+
}
53+
}
54+
55+
/** Get the latest prompt text from files in the specified directory.
56+
* If a version is provided, it retrieves that specific version.
57+
* Otherwise, it retrieves the latest version based on file naming.
58+
*
59+
* @param type The type of prompt (system, user, reformulation)
60+
* @returns An object containing the prompt text and filename
61+
*/
62+
private getTypedPrompt(type: BedrockPromptSettingsType)
63+
: { text: string; filename: string } {
64+
// Read all files in the folder
65+
const files = fs
66+
.readdirSync("../../../prompts")
67+
.filter(f => f.startsWith(`${type}_v`) && f.endsWith(".txt"))
68+
69+
if (files.length === 0) {
70+
throw new Error("No variant files found in the folder.")
71+
}
72+
73+
const file = files.find(file => file.startsWith(`${type}Prompt`))!
74+
75+
const text = fs.readFileSync(file, "utf-8")
76+
77+
return {text, filename: file}
78+
}
79+
}

packages/cdk/resources/Functions.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ const RAG_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"
1111
const QUERY_REFORMULATION_MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0"
1212
const BEDROCK_KB_DATA_SOURCE = "eps-assist-kb-ds"
1313
const LAMBDA_MEMORY_SIZE = "265"
14+
// Claude RAG inference parameters
15+
const RAG_TEMPERATURE = "0"
16+
const RAG_MAX_TOKENS = "512"
17+
const RAG_TOP_P = "1"
1418

1519
export interface FunctionsProps {
1620
readonly stackName: string
@@ -71,7 +75,10 @@ export class Functions extends Construct {
7175
"QUERY_REFORMULATION_PROMPT_NAME": props.reformulationPromptName,
7276
"RAG_RESPONSE_PROMPT_NAME": props.ragResponsePromptName,
7377
"QUERY_REFORMULATION_PROMPT_VERSION": props.reformulationPromptVersion,
74-
"RAG_RESPONSE_PROMPT_VERSION": props.ragResponsePromptVersion
78+
"RAG_RESPONSE_PROMPT_VERSION": props.ragResponsePromptVersion,
79+
"RAG_TEMPERATURE": RAG_TEMPERATURE,
80+
"RAG_MAX_TOKENS": RAG_MAX_TOKENS,
81+
"RAG_TOP_P": RAG_TOP_P
7582
}
7683
})
7784

packages/cdk/stacks/EpsAssistMeStack.ts

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import {BedrockPromptResources} from "../resources/BedrockPromptResources"
1919
import {S3LambdaNotification} from "../constructs/S3LambdaNotification"
2020
import {VectorIndex} from "../resources/VectorIndex"
2121
import {ManagedPolicy, PolicyStatement, Role} from "aws-cdk-lib/aws-iam"
22-
import {BedrockPromptCollection} from "../prompts/BedrockPromptsCollection"
22+
import {BedrockPromptSettings} from "../resources/BedrockPromptSettings"
2323

2424
export interface EpsAssistMeStackProps extends StackProps {
2525
readonly stackName: string
@@ -67,18 +67,13 @@ export class EpsAssistMeStack extends Stack {
6767
stackName: props.stackName
6868
})
6969

70-
/// TODO: Get versions during deployment - for now, default to latest
7170
// Create Bedrock Prompt Collection
72-
const bedrockPromptCollection = new BedrockPromptCollection(this, "BedrockPromptCollection", {
73-
systemPromptVersion: undefined,
74-
userPromptVersion: undefined,
75-
reformulationPromptVersion: undefined
76-
})
71+
const bedrockPromptCollection = new BedrockPromptSettings(this, "BedrockPromptCollection")
7772

7873
// Create Bedrock Prompt Resources
7974
const bedrockPromptResources = new BedrockPromptResources(this, "BedrockPromptResources", {
8075
stackName: props.stackName,
81-
collection: bedrockPromptCollection
76+
settings: bedrockPromptCollection
8277
})
8378

8479
// Create Storage construct first as it has no dependencies
@@ -105,7 +100,8 @@ export class EpsAssistMeStack extends Stack {
105100
stackName: props.stackName,
106101
collection: openSearchResources.collection
107102
})
108-
// this dependency ensures the OpenSearch access policy is created before the VectorIndex
103+
104+
// This dependency ensures the OpenSearch access policy is created before the VectorIndex
109105
// and deleted after the VectorIndex is deleted to prevent deletion or deployment failures
110106
vectorIndex.node.addDependency(openSearchResources.deploymentPolicy)
111107

packages/cdk/tsconfig.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,6 @@
2626
]
2727
},
2828
"references": [],
29-
"include": ["resources/**/*", "constructs/**/*", "stacks/**/*", "tests/**/*", "prompts/**/*", "nagSuppressions.ts"],
29+
"include": ["resources/**/*", "constructs/**/*", "stacks/**/*", "tests/**/*", "resources/prompts/**/*", "nagSuppressions.ts"],
3030
"exclude": ["node_modules", "cdk.out"]
3131
}

packages/slackBotFunction/app/core/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_bot_token() -> str:
7171

7272

7373
@lru_cache
74-
def get_retrieve_generate_config() -> Tuple[str, str, str, str, str, str, str]:
74+
def get_retrieve_generate_config() -> Tuple[str, str, str, str, str, str, str, str, str, str]:
7575
# Bedrock configuration from environment
7676
KNOWLEDGEBASE_ID = os.environ["KNOWLEDGEBASE_ID"]
7777
RAG_MODEL_ID = os.environ["RAG_MODEL_ID"]
@@ -80,6 +80,9 @@ def get_retrieve_generate_config() -> Tuple[str, str, str, str, str, str, str]:
8080
GUARD_VERSION = os.environ["GUARD_RAIL_VERSION"]
8181
RAG_RESPONSE_PROMPT_NAME = os.environ["RAG_RESPONSE_PROMPT_NAME"]
8282
RAG_RESPONSE_PROMPT_VERSION = os.environ["RAG_RESPONSE_PROMPT_VERSION"]
83+
RAG_TEMPERATURE = os.environ["RAG_TEMPERATURE"]
84+
RAG_MAX_TOKENS = os.environ["RAG_MAX_TOKENS"]
85+
RAG_TOP_P = os.environ["RAG_TOP_P"]
8386

8487
logger.info(
8588
"Guardrail configuration loaded", extra={"guardrail_id": GUARD_RAIL_ID, "guardrail_version": GUARD_VERSION}
@@ -92,6 +95,9 @@ def get_retrieve_generate_config() -> Tuple[str, str, str, str, str, str, str]:
9295
GUARD_VERSION,
9396
RAG_RESPONSE_PROMPT_NAME,
9497
RAG_RESPONSE_PROMPT_VERSION,
98+
RAG_TEMPERATURE,
99+
RAG_MAX_TOKENS,
100+
RAG_TOP_P,
95101
)
96102

97103

packages/slackBotFunction/app/services/bedrock.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import os
23
from typing import Any
34
import boto3
45
from mypy_boto3_bedrock_agent_runtime import AgentsforBedrockRuntimeClient
@@ -28,6 +29,9 @@ def query_bedrock(user_query: str, session_id: str = None) -> RetrieveAndGenerat
2829
GUARD_VERSION,
2930
RAG_RESPONSE_PROMPT_NAME,
3031
RAG_RESPONSE_PROMPT_VERSION,
32+
RAG_TEMPERATURE,
33+
RAG_MAX_TOKENS,
34+
RAG_TOP_P,
3135
) = get_retrieve_generate_config()
3236

3337
prompt_template = load_prompt(RAG_RESPONSE_PROMPT_NAME, RAG_RESPONSE_PROMPT_VERSION)
@@ -50,9 +54,9 @@ def query_bedrock(user_query: str, session_id: str = None) -> RetrieveAndGenerat
5054
},
5155
"inferenceConfig": {
5256
"textInferenceConfig": {
53-
"temperature": 0,
54-
"topP": 1,
55-
"maxTokens": 512,
57+
"temperature": RAG_TEMPERATURE,
58+
"topP": RAG_TOP_P,
59+
"maxTokens": RAG_MAX_TOKENS,
5660
"stopSequences": [
5761
"Human:",
5862
],
@@ -92,10 +96,10 @@ def invoke_model(prompt: str, model_id: str, client: BedrockRuntimeClient) -> di
9296
body=json.dumps(
9397
{
9498
"anthropic_version": "bedrock-2023-05-31",
95-
"temperature": 0.1,
96-
"top_p": 0.9,
99+
"temperature": os.environ.get("RAG_TEMPERATURE", "1"),
100+
"top_p": os.environ.get("RAG_TOP_P", "1"),
97101
"top_k": 50,
98-
"max_tokens": 150,
102+
"max_tokens": os.environ.get("RAG_MAX_TOKENS", "512"),
99103
"messages": [{"role": "user", "content": prompt}],
100104
}
101105
),

packages/slackBotFunction/tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def mock_env():
2626
"QUERY_REFORMULATION_PROMPT_VERSION": "DRAFT",
2727
"RAG_RESPONSE_PROMPT_NAME": "test-rag-prompt",
2828
"RAG_RESPONSE_PROMPT_VERSION": "DRAFT",
29+
"RAG_TEMPERATURE": "0.5",
30+
"RAG_MAX_TOKENS": "1024",
31+
"RAG_TOP_P": "0.9",
2932
}
3033
env_vars["AWS_DEFAULT_REGION"] = env_vars["AWS_REGION"]
3134
with patch.dict(os.environ, env_vars, clear=False):

0 commit comments

Comments
 (0)