Skip to content

Commit b16fad3

Browse files
authored
Manage system prompt from Chat API (#64)
1 parent 5db68bb commit b16fad3

File tree

14 files changed

+1008
-173
lines changed

14 files changed

+1008
-173
lines changed

Cargo.lock

Lines changed: 36 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/api/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ chrono = "0.4"
5555
tokio = { version = "1", features = ["full"] }
5656
axum-test = "18"
5757
dotenvy = "0.15"
58+
wiremock = "0.6"
5859

5960
[features]
6061
default = []

crates/api/src/main.rs

Lines changed: 36 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,18 @@
11
use api::{create_router_with_cors, ApiDoc, AppState};
2-
use hmac::{Hmac, Mac};
32
use services::{
4-
auth::OAuthServiceImpl, conversation::service::ConversationServiceImpl,
5-
file::service::FileServiceImpl, response::service::OpenAIProxy, user::UserServiceImpl,
3+
auth::OAuthServiceImpl,
4+
conversation::service::ConversationServiceImpl,
5+
file::service::FileServiceImpl,
6+
response::service::OpenAIProxy,
7+
user::UserServiceImpl,
68
user::UserSettingsServiceImpl,
9+
vpc::{initialize_vpc_credentials, VpcAuthConfig},
710
};
8-
use sha2::Sha256;
911
use std::sync::Arc;
10-
use std::time::{SystemTime, UNIX_EPOCH};
1112
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
1213
use utoipa::OpenApi;
1314
use utoipa_swagger_ui::SwaggerUi;
1415

15-
type HmacSha256 = Hmac<Sha256>;
16-
17-
/// Response from VPC login endpoint
18-
#[derive(serde::Deserialize)]
19-
struct VpcLoginResponse {
20-
api_key: String,
21-
}
22-
23-
/// Performs VPC authentication to obtain an API key
24-
async fn vpc_authenticate(
25-
config: &config::VpcAuthConfig,
26-
base_url: &str,
27-
) -> anyhow::Result<String> {
28-
let shared_secret = config
29-
.read_shared_secret()
30-
.ok_or_else(|| anyhow::anyhow!("Failed to read VPC shared secret"))?;
31-
32-
// Generate timestamp
33-
let timestamp = SystemTime::now()
34-
.duration_since(UNIX_EPOCH)
35-
.expect("Time went backwards")
36-
.as_secs();
37-
38-
// Generate HMAC-SHA256 signature
39-
let mut mac = HmacSha256::new_from_slice(shared_secret.as_bytes())
40-
.expect("HMAC can take key of any size");
41-
mac.update(timestamp.to_string().as_bytes());
42-
let signature = hex::encode(mac.finalize().into_bytes());
43-
44-
tracing::info!(
45-
"Performing VPC authentication with client_id: {}",
46-
config.client_id
47-
);
48-
tracing::debug!("VPC auth timestamp: {}", timestamp);
49-
50-
// Build the auth URL
51-
let auth_url = format!("{}/auth/vpc/login", base_url.trim_end_matches('/'));
52-
53-
// Make authentication request
54-
let client = reqwest::Client::new();
55-
let response = client
56-
.post(&auth_url)
57-
.header("Content-Type", "application/json")
58-
.json(&serde_json::json!({
59-
"timestamp": timestamp,
60-
"signature": signature,
61-
"client_id": config.client_id
62-
}))
63-
.send()
64-
.await?;
65-
66-
if !response.status().is_success() {
67-
let status = response.status();
68-
let body = response.text().await.unwrap_or_default();
69-
anyhow::bail!("VPC authentication failed with status {}: {}", status, body);
70-
}
71-
72-
let login_response: VpcLoginResponse = response.json().await?;
73-
tracing::info!("VPC authentication successful");
74-
75-
Ok(login_response.api_key)
76-
}
77-
7816
#[tokio::main]
7917
async fn main() -> anyhow::Result<()> {
8018
// Load .env file if it exists
@@ -148,48 +86,43 @@ async fn main() -> anyhow::Result<()> {
14886
user_settings_repo as Arc<dyn services::user::ports::UserSettingsRepository>,
14987
));
15088

151-
// Get OpenAI API key - check database first, then VPC auth, then config
152-
const VPC_API_KEY_CONFIG_KEY: &str = "vpc_api_key";
153-
154-
let api_key = if config.vpc_auth.is_configured() {
89+
// Initialize VPC credentials service and get API key
90+
let vpc_auth_config = if config.vpc_auth.is_configured() {
15591
let base_url = config.openai.base_url.as_ref().ok_or_else(|| {
15692
anyhow::anyhow!("OPENAI_BASE_URL is required when using VPC authentication")
15793
})?;
15894

159-
// Check if we have a cached API key in the database
160-
match app_config_repo.get(VPC_API_KEY_CONFIG_KEY).await {
161-
Ok(Some(cached_key)) => {
162-
tracing::info!("Using cached VPC API key from database");
163-
cached_key
164-
}
165-
Ok(None) => {
166-
tracing::info!("No cached API key found, performing VPC authentication...");
167-
let new_key = vpc_authenticate(&config.vpc_auth, base_url).await?;
168-
169-
// Store the new key in the database
170-
if let Err(e) = app_config_repo.set(VPC_API_KEY_CONFIG_KEY, &new_key).await {
171-
tracing::warn!("Failed to cache VPC API key in database: {}", e);
172-
} else {
173-
tracing::info!("VPC API key cached in database");
174-
}
175-
176-
new_key
177-
}
178-
Err(e) => {
179-
tracing::warn!(
180-
"Failed to check for cached API key: {}, performing VPC auth...",
181-
e
182-
);
183-
vpc_authenticate(&config.vpc_auth, base_url).await?
184-
}
185-
}
95+
let shared_secret = config
96+
.vpc_auth
97+
.read_shared_secret()
98+
.ok_or_else(|| anyhow::anyhow!("Failed to read VPC shared secret"))?;
99+
100+
Some(VpcAuthConfig {
101+
client_id: config.vpc_auth.client_id.clone(),
102+
shared_secret,
103+
base_url: base_url.clone(),
104+
})
186105
} else {
106+
None
107+
};
108+
109+
let static_api_key = if vpc_auth_config.is_none() {
187110
tracing::info!("Using API key from environment");
188-
config.openai.api_key.clone()
111+
Some(config.openai.api_key.clone())
112+
} else {
113+
None
189114
};
190115

116+
tracing::info!("Initializing VPC credentials service...");
117+
let vpc_credentials_service = initialize_vpc_credentials(
118+
vpc_auth_config,
119+
app_config_repo.clone() as Arc<dyn services::vpc::VpcCredentialsRepository>,
120+
static_api_key,
121+
)
122+
.await?;
123+
191124
// Initialize OpenAI proxy service
192-
let mut proxy_service = OpenAIProxy::new(api_key);
125+
let mut proxy_service = OpenAIProxy::new(vpc_credentials_service.clone());
193126
if let Some(base_url) = config.openai.base_url.clone() {
194127
proxy_service = proxy_service.with_base_url(base_url);
195128
}
@@ -216,6 +149,8 @@ async fn main() -> anyhow::Result<()> {
216149
redirect_uri: config.oauth.redirect_uri,
217150
admin_domains: Arc::new(config.admin.admin_domains),
218151
user_repository: user_repo.clone(),
152+
vpc_credentials_service,
153+
cloud_api_base_url: config.openai.base_url.clone().unwrap_or_default(),
219154
};
220155

221156
// Create router with CORS support

0 commit comments

Comments
 (0)