diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index 34974d1e..82f2fd5a 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -12,6 +12,7 @@ use services::{ conversation::service::ConversationServiceImpl, file::service::FileServiceImpl, metrics::{MockMetricsService, OtlpMetricsService}, + model::service::ModelServiceImpl, response::service::OpenAIProxy, user::UserServiceImpl, user::UserSettingsServiceImpl, @@ -71,6 +72,8 @@ async fn main() -> anyhow::Result<()> { let app_config_repo = db.app_config_repository(); let near_nonce_repo = db.near_nonce_repository(); let analytics_repo = db.analytics_repository(); + let system_configs_repo = db.system_configs_repository(); + let model_repo = db.model_repository(); // Create services tracing::info!("Initializing services..."); @@ -89,9 +92,9 @@ async fn main() -> anyhow::Result<()> { let user_service = Arc::new(UserServiceImpl::new(user_repo.clone())); - let user_settings_service = Arc::new(UserSettingsServiceImpl::new( - user_settings_repo as Arc, - )); + let user_settings_service = Arc::new(UserSettingsServiceImpl::new(user_settings_repo)); + + let model_service = Arc::new(ModelServiceImpl::new(model_repo)); // Initialize VPC credentials service and get API key let vpc_auth_config = if config.vpc_auth.is_configured() { @@ -150,6 +153,15 @@ async fn main() -> anyhow::Result<()> { analytics_repo as Arc, )); + // Initialize system configs service + tracing::info!("Initializing system configs service..."); + let system_configs_service = Arc::new( + services::system_configs::service::SystemConfigsServiceImpl::new( + system_configs_repo + as Arc, + ), + ); + // Initialize metrics service tracing::info!("Initializing metrics service..."); let metrics_service: Arc = @@ -209,6 +221,8 @@ async fn main() -> anyhow::Result<()> { oauth_service, user_service, user_settings_service, + model_service, + system_configs_service, session_repository: session_repo, proxy_service, conversation_service, @@ -222,6 +236,7 @@ async fn main() -> anyhow::Result<()> { analytics_service, near_rpc_url: config.near.rpc_url.clone(), near_balance_cache: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())), + model_settings_cache: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())), }; // Create router with CORS support diff --git a/crates/api/src/models.rs b/crates/api/src/models.rs index 86fda17e..408fd381 100644 --- a/crates/api/src/models.rs +++ b/crates/api/src/models.rs @@ -191,16 +191,6 @@ impl From for UserSettingsContent { } } -/// User settings response -#[derive(Debug, Serialize, Deserialize, ToSchema)] -pub struct UserSettingsResponse { - /// User ID - pub user_id: UserId, - /// Settings content (serialized as "settings") - #[serde(rename = "settings")] - pub content: UserSettingsContent, -} - /// User settings update request #[derive(Debug, Serialize, Deserialize, ToSchema)] pub struct UpdateUserSettingsRequest { @@ -251,6 +241,124 @@ impl UpdateUserSettingsPartiallyRequest { } } +/// User settings response +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct UserSettingsResponse { + /// User ID + pub user_id: UserId, + /// Settings content (serialized as "settings") + #[serde(rename = "settings")] + pub content: UserSettingsContent, +} + +/// Model settings content for API responses +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct ModelSettings { + /// Whether models are public (visible/usable in responses) + pub public: bool, + /// Optional system-level system prompt for this model + #[serde(skip_serializing_if = "Option::is_none")] + pub system_prompt: Option, +} + +impl From for ModelSettings { + fn from(content: services::model::ports::ModelSettings) -> Self { + Self { + public: content.public, + system_prompt: content.system_prompt, + } + } +} + +impl From for services::model::ports::ModelSettings { + fn from(content: ModelSettings) -> Self { + Self { + public: content.public, + system_prompt: content.system_prompt, + } + } +} + +/// Partial model settings for API requests +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct PartialModelSettings { + /// Whether models are public (visible/usable in responses) + pub public: Option, + /// Optional system-level system prompt for this model + pub system_prompt: Option, +} + +impl From for PartialModelSettings { + fn from(content: services::model::ports::PartialModelSettings) -> Self { + Self { + public: content.public, + system_prompt: content.system_prompt, + } + } +} + +impl From for services::model::ports::PartialModelSettings { + fn from(content: PartialModelSettings) -> Self { + Self { + public: content.public, + system_prompt: content.system_prompt, + } + } +} + +/// Complete model response +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct ModelResponse { + /// External model identifier (e.g. "gpt-4.1") + pub model_id: String, + /// Settings stored for this model + pub settings: ModelSettings, +} + +impl From for ModelResponse { + fn from(model: services::model::ports::Model) -> Self { + Self { + model_id: model.model_id, + settings: model.settings.into(), + } + } +} + +/// Model upsert request +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct UpsertModelsRequest { + pub settings: ModelSettings, +} + +/// Model settings update request (partial update) +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct UpdateModelRequest { + pub settings: Option, +} + +/// Batch model upsert request +/// +/// Maps model_id to partial settings to update. +/// Example: { "gpt-4": { "public": true }, "gpt-3.5": { "public": false, "system_prompt": "..." } } +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct BatchUpsertModelsRequest { + #[serde(flatten)] + pub models: std::collections::HashMap, +} + +/// Model list response with pagination +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct ModelListResponse { + /// List of models + pub models: Vec, + /// Maximum number of items returned + pub limit: i64, + /// Number of items skipped + pub offset: i64, + /// Total number of models + pub total: i64, +} + /// Paginated user list response #[derive(Debug, Serialize, Deserialize, ToSchema)] pub struct UserListResponse { @@ -264,6 +372,54 @@ pub struct UserListResponse { pub total: u64, } +/// System configs response +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct SystemConfigsResponse { + /// Default model identifier to use when not specified + #[serde(skip_serializing_if = "Option::is_none")] + pub default_model: Option, +} + +impl From for SystemConfigsResponse { + fn from(config: services::system_configs::ports::SystemConfigs) -> Self { + Self { + default_model: config.default_model, + } + } +} + +/// System configs upsert request (full replace) +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct UpsertSystemConfigsRequest { + /// Default model identifier to use when not specified + #[serde(skip_serializing_if = "Option::is_none")] + pub default_model: Option, +} + +impl From for services::system_configs::ports::SystemConfigs { + fn from(req: UpsertSystemConfigsRequest) -> Self { + services::system_configs::ports::SystemConfigs { + default_model: req.default_model, + } + } +} + +/// System configs update request (partial) +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct UpdateSystemConfigsRequest { + /// Default model identifier to use when not specified + #[serde(skip_serializing_if = "Option::is_none")] + pub default_model: Option, +} + +impl From for services::system_configs::ports::PartialSystemConfigs { + fn from(req: UpdateSystemConfigsRequest) -> Self { + services::system_configs::ports::PartialSystemConfigs { + default_model: req.default_model, + } + } +} + /// File list response with pagination #[derive(Debug, Serialize, Deserialize, ToSchema)] pub struct FileListResponse { diff --git a/crates/api/src/openapi.rs b/crates/api/src/openapi.rs index 80f759b7..edaec15a 100644 --- a/crates/api/src/openapi.rs +++ b/crates/api/src/openapi.rs @@ -24,6 +24,12 @@ use utoipa::OpenApi; crate::routes::users::get_current_user, // Admin endpoints crate::routes::admin::list_users, + crate::routes::admin::list_models, + crate::routes::admin::batch_upsert_models, + crate::routes::admin::delete_model, + crate::routes::admin::upsert_system_configs, + // Configs endpoints + crate::routes::configs::get_system_configs, crate::routes::users::get_user_settings, crate::routes::users::update_user_settings_partially, crate::routes::users::update_user_settings, @@ -47,6 +53,14 @@ use utoipa::OpenApi; crate::models::UserSettingsResponse, crate::models::UpdateUserSettingsPartiallyRequest, crate::models::UpdateUserSettingsRequest, + // Model settings / model admin models + crate::models::ModelResponse, + crate::models::ModelListResponse, + crate::models::BatchUpsertModelsRequest, + crate::models::UpdateModelRequest, + // System configs models + crate::models::SystemConfigsResponse, + crate::models::UpdateSystemConfigsRequest, // Attestation models crate::models::ApiGatewayAttestation, crate::models::ModelAttestation, @@ -58,6 +72,7 @@ use utoipa::OpenApi; (name = "Auth", description = "OAuth authentication endpoints"), (name = "Users", description = "User profile management endpoints"), (name = "Admin", description = "Admin management endpoints"), + (name = "Configs", description = "System configuration endpoints"), (name = "attestation", description = "Attestation reporting endpoints for TEE verification") ) )] diff --git a/crates/api/src/routes/admin.rs b/crates/api/src/routes/admin.rs index f841d65e..c9c13cb3 100644 --- a/crates/api/src/routes/admin.rs +++ b/crates/api/src/routes/admin.rs @@ -1,12 +1,14 @@ use crate::{consts::LIST_USERS_LIMIT_MAX, error::ApiError, models::*, state::AppState}; use axum::{ extract::{Path, Query, State}, - routing::get, + http::StatusCode, + routing::{delete, get, patch}, Json, Router, }; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use services::analytics::{ActivityLogEntry, AnalyticsSummary, TopActiveUsersResponse}; +use services::model::ports::{UpdateModelParams, UpsertModelParams}; use services::UserId; /// Pagination query parameters @@ -102,196 +104,6 @@ pub async fn list_users( })) } -#[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)] -pub struct SystemPromptRequest { - pub system_prompt: Option, -} - -#[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)] -pub struct SystemPromptResponse { - pub system_prompt: Option, -} - -#[derive(Debug, Deserialize, Serialize)] -struct CloudApiSettingsResponse { - settings: CloudApiSettings, -} - -#[derive(Debug, Deserialize, Serialize)] -struct CloudApiSettings { - system_prompt: Option, -} - -#[derive(Debug, Serialize)] -struct CloudApiPatchRequest { - system_prompt: Option, -} - -/// Get system prompt for the organization -/// -/// Fetches the system prompt from Cloud API. Requires admin authentication. -/// Uses the VPC session token to authenticate with Cloud API. -#[utoipa::path( - get, - path = "/v1/admin/system_prompt", - tag = "Admin", - responses( - (status = 200, description = "System prompt retrieved", body = SystemPromptResponse), - (status = 401, description = "Unauthorized", body = crate::error::ApiErrorResponse), - (status = 403, description = "Forbidden - Admin access required", body = crate::error::ApiErrorResponse), - (status = 500, description = "Internal server error", body = crate::error::ApiErrorResponse) - ), - security( - ("session_token" = []) - ) -)] -pub async fn get_system_prompt( - State(app_state): State, -) -> Result, ApiError> { - if app_state.cloud_api_base_url.is_empty() { - tracing::error!("Cloud API base URL not configured"); - return Err(ApiError::internal_server_error("Cloud API not configured")); - } - - let credentials = app_state - .vpc_credentials_service - .get_credentials() - .await - .map_err(|e| { - tracing::error!("Failed to get VPC credentials: {}", e); - ApiError::internal_server_error("Failed to authenticate with VPC") - })? - .ok_or_else(|| { - tracing::error!("VPC not configured"); - ApiError::internal_server_error("VPC authentication not configured") - })?; - - let url = format!( - "{}/organizations/{}/settings", - app_state.cloud_api_base_url.trim_end_matches('/'), - credentials.organization_id - ); - - let client = reqwest::Client::new(); - let response = client - .get(&url) - .header( - "Authorization", - format!("Bearer {}", credentials.access_token), - ) - .send() - .await - .map_err(|e| { - tracing::error!("Failed to call Cloud API: {}", e); - ApiError::bad_gateway("Failed to connect to Cloud API") - })?; - - if !response.status().is_success() { - let status = response.status(); - tracing::error!("Cloud API error: status {}", status); - return Err(ApiError::internal_server_error(format!( - "Cloud API returned error: {}", - status - ))); - } - - let settings: CloudApiSettingsResponse = response.json().await.map_err(|e| { - tracing::error!("Failed to parse Cloud API response: {}", e); - ApiError::internal_server_error("Failed to parse Cloud API response") - })?; - - Ok(Json(SystemPromptResponse { - system_prompt: settings.settings.system_prompt, - })) -} - -/// Set system prompt for the organization -/// -/// Updates the system prompt in Cloud API. Requires admin authentication. -/// Uses the VPC session token to authenticate with Cloud API. -#[utoipa::path( - post, - path = "/v1/admin/system_prompt", - tag = "Admin", - request_body = SystemPromptRequest, - responses( - (status = 200, description = "System prompt updated", body = SystemPromptResponse), - (status = 400, description = "Bad request", body = crate::error::ApiErrorResponse), - (status = 401, description = "Unauthorized", body = crate::error::ApiErrorResponse), - (status = 403, description = "Forbidden - Admin access required", body = crate::error::ApiErrorResponse), - (status = 500, description = "Internal server error", body = crate::error::ApiErrorResponse) - ), - security( - ("session_token" = []) - ) -)] -pub async fn set_system_prompt( - State(app_state): State, - Json(request): Json, -) -> Result, ApiError> { - if app_state.cloud_api_base_url.is_empty() { - tracing::error!("Cloud API base URL not configured"); - return Err(ApiError::internal_server_error("Cloud API not configured")); - } - - let credentials = app_state - .vpc_credentials_service - .get_credentials() - .await - .map_err(|e| { - tracing::error!("Failed to get VPC credentials: {}", e); - ApiError::internal_server_error("Failed to authenticate with VPC") - })? - .ok_or_else(|| { - tracing::error!("VPC not configured"); - ApiError::internal_server_error("VPC authentication not configured") - })?; - - let url = format!( - "{}/organizations/{}/settings", - app_state.cloud_api_base_url.trim_end_matches('/'), - credentials.organization_id - ); - - let patch_request = CloudApiPatchRequest { - system_prompt: request.system_prompt, - }; - - let client = reqwest::Client::new(); - let response = client - .patch(&url) - .header( - "Authorization", - format!("Bearer {}", credentials.access_token), - ) - .header("Content-Type", "application/json") - .json(&patch_request) - .send() - .await - .map_err(|e| { - tracing::error!("Failed to call Cloud API: {}", e); - ApiError::bad_gateway("Failed to connect to Cloud API") - })?; - - if !response.status().is_success() { - let status = response.status(); - tracing::error!("Cloud API error: status {}", status); - return Err(ApiError::internal_server_error(format!( - "Cloud API returned error: {}", - status - ))); - } - - let settings: CloudApiSettingsResponse = response.json().await.map_err(|e| { - tracing::error!("Failed to parse Cloud API response: {}", e); - ApiError::internal_server_error("Failed to parse Cloud API response") - })?; - - Ok(Json(SystemPromptResponse { - system_prompt: settings.settings.system_prompt, - })) -} - /// Query parameters for analytics endpoint #[derive(Debug, Deserialize)] pub struct AnalyticsQuery { @@ -489,15 +301,310 @@ pub async fn get_top_users( })) } +/// List all models with pagination +/// +/// Returns a paginated list of all models. Requires admin authentication. +#[utoipa::path( + get, + path = "/v1/admin/models", + tag = "Admin", + params( + ("limit" = i64, Query, description = "Maximum number of items to return"), + ("offset" = i64, Query, description = "Number of items to skip") + ), + responses( + (status = 200, description = "List of models", body = ModelListResponse), + (status = 400, description = "Bad request", body = crate::error::ApiErrorResponse), + (status = 401, description = "Unauthorized", body = crate::error::ApiErrorResponse), + (status = 403, description = "Forbidden - Admin access required", body = crate::error::ApiErrorResponse), + (status = 500, description = "Internal server error", body = crate::error::ApiErrorResponse) + ), + security( + ("session_token" = []) + ) +)] +pub async fn list_models( + State(app_state): State, + Query(pagination): Query, +) -> Result, ApiError> { + pagination.validate()?; + + tracing::info!( + "Listing models with limit={} and offset={}", + pagination.limit, + pagination.offset + ); + + let (models, total) = app_state + .model_service + .list_models(pagination.limit, pagination.offset) + .await + .map_err(|e| { + tracing::error!("Failed to list models: {}", e); + ApiError::internal_server_error("Failed to list models") + })?; + + Ok(Json(ModelListResponse { + models: models.into_iter().map(Into::into).collect(), + limit: pagination.limit, + offset: pagination.offset, + total, + })) +} + +/// Batch create or update models +/// +/// Creates new models or updates existing ones in batch. The request body should be a JSON object +/// where keys are model IDs and values are partial settings to update. +/// +/// Example: +/// ```json +/// { +/// "gpt-4": { "public": true, "system_prompt": "..." }, +/// "gpt-3.5": { "public": false } +/// } +/// ``` +/// +/// If a model doesn't exist, missing fields will use default values. +/// If a model exists, only provided fields will be updated. +/// Requires admin authentication. +#[utoipa::path( + patch, + path = "/v1/admin/models", + tag = "Admin", + request_body = BatchUpsertModelsRequest, + responses( + (status = 200, description = "Models created or updated", body = Vec), + (status = 400, description = "Bad request", body = crate::error::ApiErrorResponse), + (status = 401, description = "Unauthorized", body = crate::error::ApiErrorResponse), + (status = 403, description = "Forbidden - Admin access required", body = crate::error::ApiErrorResponse), + (status = 500, description = "Internal server error", body = crate::error::ApiErrorResponse) + ), + security( + ("session_token" = []) + ) +)] +pub async fn batch_upsert_models( + State(app_state): State, + Json(request): Json, +) -> Result>, ApiError> { + if request.models.is_empty() { + return Err(ApiError::bad_request("At least one model must be provided")); + } + + tracing::info!("Batch upserting {} models", request.models.len()); + + use services::model::ports::{ModelSettings, PartialModelSettings}; + + let mut results = Vec::new(); + + for (model_id, partial_settings) in request.models { + if model_id.trim().is_empty() { + return Err(ApiError::bad_request("model_id cannot be empty")); + } + + // Validate system prompt length if provided + if let Some(ref system_prompt) = partial_settings.system_prompt { + if system_prompt.len() > crate::consts::SYSTEM_PROMPT_MAX_LEN { + return Err(ApiError::bad_request(format!( + "System prompt for model '{}' exceeds maximum length of {} bytes", + model_id, + crate::consts::SYSTEM_PROMPT_MAX_LEN + ))); + } + } + + // Check if model exists + let existing_model = app_state + .model_service + .get_model(&model_id) + .await + .map_err(|e| { + tracing::error!("Failed to check if model exists: {}", e); + ApiError::internal_server_error("Failed to check if model exists") + })?; + + let model = if existing_model.is_some() { + // Model exists: partial update + let settings: PartialModelSettings = partial_settings.into(); + let params = UpdateModelParams { + model_id: model_id.clone(), + settings: Some(settings), + }; + + app_state + .model_service + .update_model(params) + .await + .map_err(|e| { + tracing::error!("Failed to update model {}: {}", model_id, e); + ApiError::internal_server_error(format!("Failed to update model {}", model_id)) + })? + } else { + // Model doesn't exist: create with defaults + provided partial settings + let default_settings = ModelSettings::default(); + let partial: PartialModelSettings = partial_settings.into(); + let full_settings = default_settings.into_updated(partial); + + let params = UpsertModelParams { + model_id: model_id.clone(), + settings: full_settings, + }; + + app_state + .model_service + .upsert_model(params) + .await + .map_err(|e| { + tracing::error!("Failed to create model {}: {}", model_id, e); + ApiError::internal_server_error(format!("Failed to create model {}", model_id)) + })? + }; + + // Invalidate cache immediately after each successful DB write + // NOTE: This only invalidates cache on the current instance. In multi-instance deployments, + // other instances may serve stale data for up to MODEL_SETTINGS_CACHE_TTL_SECS. + { + let mut cache = app_state.model_settings_cache.write().await; + cache.remove(&model_id); + } + + results.push(model.clone()); + } + + Ok(Json(results.into_iter().map(Into::into).collect())) +} + +/// Delete a model +/// +/// Deletes a specific model and its settings. Requires admin authentication. +#[utoipa::path( + delete, + path = "/v1/admin/models/{model_id}", + tag = "Admin", + params( + ("model_id" = String, Path, description = "Model identifier (e.g. gpt-4.1)") + ), + responses( + (status = 204, description = "Model deleted"), + (status = 401, description = "Unauthorized", body = crate::error::ApiErrorResponse), + (status = 403, description = "Forbidden - Admin access required", body = crate::error::ApiErrorResponse), + (status = 404, description = "Model not found", body = crate::error::ApiErrorResponse), + (status = 500, description = "Internal server error", body = crate::error::ApiErrorResponse) + ), + security( + ("session_token" = []) + ) +)] +pub async fn delete_model( + State(app_state): State, + Path(model_id): Path, +) -> Result { + if model_id.trim().is_empty() { + return Err(ApiError::bad_request("model_id cannot be empty")); + } + + tracing::info!("Deleting model for model_id={}", model_id); + + let deleted = app_state + .model_service + .delete_model(&model_id) + .await + .map_err(|e| { + tracing::error!("Failed to delete model: {}", e); + ApiError::internal_server_error("Failed to delete model") + })?; + + if !deleted { + return Err(ApiError::not_found("Model not found")); + } + + // Invalidate cache AFTER successful DB delete + { + let mut cache = app_state.model_settings_cache.write().await; + cache.remove(&model_id); + } + + Ok(StatusCode::NO_CONTENT) +} + +/// Create or update system configs +/// +/// Creates new system configs or updates existing ones. All fields in the request are optional. +/// If the configs don't exist, missing fields will use default values. +/// If the configs exist, only provided fields will be updated. +/// Requires admin authentication. +#[utoipa::path( + patch, + path = "/v1/admin/configs", + tag = "Admin", + request_body = UpdateSystemConfigsRequest, + responses( + (status = 200, description = "System configs created or updated", body = SystemConfigsResponse), + (status = 400, description = "Bad request", body = crate::error::ApiErrorResponse), + (status = 401, description = "Unauthorized", body = crate::error::ApiErrorResponse), + (status = 403, description = "Forbidden - Admin access required", body = crate::error::ApiErrorResponse), + (status = 500, description = "Internal server error", body = crate::error::ApiErrorResponse) + ), + security( + ("session_token" = []) + ) +)] +pub async fn upsert_system_configs( + State(app_state): State, + Json(request): Json, +) -> Result, ApiError> { + tracing::info!("Upserting system configs"); + + let partial: services::system_configs::ports::PartialSystemConfigs = request.into(); + + // Check if configs exist + let existing_configs = app_state + .system_configs_service + .get_configs() + .await + .map_err(|e| { + tracing::error!("Failed to check if system configs exist: {}", e); + ApiError::internal_server_error("Failed to check if system configs exist") + })?; + + let updated = if existing_configs.is_some() { + // Configs exist: partial update + app_state + .system_configs_service + .update_configs(partial) + .await + .map_err(|e| { + tracing::error!(error = ?e, "Failed to update system configs"); + ApiError::internal_server_error("Failed to update system configs") + })? + } else { + // Configs don't exist: create with defaults + provided partial configs + use services::system_configs::ports::SystemConfigs; + let default_configs = SystemConfigs::default(); + let full_configs = default_configs.into_updated(partial); + + app_state + .system_configs_service + .upsert_configs(full_configs) + .await + .map_err(|e| { + tracing::error!(error = ?e, "Failed to create system configs"); + ApiError::internal_server_error("Failed to create system configs") + })? + }; + + Ok(Json(updated.into())) +} + /// Create admin router with all admin routes (requires admin authentication) pub fn create_admin_router() -> Router { Router::new() .route("/users", get(list_users)) .route("/users/{user_id}/activity", get(get_user_activity)) - .route( - "/system_prompt", - get(get_system_prompt).post(set_system_prompt), - ) + .route("/models", get(list_models).patch(batch_upsert_models)) + .route("/models/{model_id}", delete(delete_model)) + .route("/configs", patch(upsert_system_configs)) .route("/analytics", get(get_analytics)) .route("/analytics/top-users", get(get_top_users)) } diff --git a/crates/api/src/routes/api.rs b/crates/api/src/routes/api.rs index 299d2c00..d594ee27 100644 --- a/crates/api/src/routes/api.rs +++ b/crates/api/src/routes/api.rs @@ -12,9 +12,11 @@ use bytes::Bytes; use chrono::{Duration, Utc}; use flate2::read::GzDecoder; use futures::TryStreamExt; +use http::{HeaderName, HeaderValue}; use near_api::{Account, AccountId, NetworkConfig}; use serde::{Deserialize, Serialize}; use services::analytics::{ActivityType, RecordActivityRequest}; +use services::consts::MODEL_PUBLIC_DEFAULT; use services::conversation::ports::ConversationError; use services::file::ports::FileError; use services::metrics::consts::{ @@ -23,6 +25,7 @@ use services::metrics::consts::{ use services::response::ports::ProxyResponse; use services::user::ports::{BanType, OAuthProvider}; use services::UserId; +use sha2::{Digest, Sha256}; use std::io::Read; /// Minimum required NEAR balance (1 NEAR in yoctoNEAR: 10^24) @@ -34,6 +37,9 @@ const NEAR_BALANCE_BAN_DURATION_SECS: i64 = 60 * 60; /// Duration to cache NEAR balance checks in memory (in seconds) const NEAR_BALANCE_CACHE_TTL_SECS: i64 = 5 * 60; +/// Duration to cache model settings needed by /v1/responses in memory (in seconds) +const MODEL_SETTINGS_CACHE_TTL_SECS: i64 = 60; + /// Error message when a user is banned pub const USER_BANNED_ERROR_MESSAGE: &str = "Access temporarily restricted. Please try again later."; @@ -1059,7 +1065,7 @@ async fn get_file_content( async fn proxy_responses( State(state): State, Extension(user): Extension, - headers: HeaderMap, + mut headers: HeaderMap, request: Request, ) -> Result { tracing::info!( @@ -1078,6 +1084,15 @@ async fn proxy_responses( // Extract body bytes let body_bytes = extract_body_bytes(request).await?; + // Parsed JSON body (if applicable) + let mut body_json: Option = None; + // Optional system prompt resolved from model settings + let mut model_system_prompt: Option = None; + // Model id from request body, if present + let mut model_id_from_body: Option = None; + // Whether model settings came from cache + let mut model_settings_cache_hit: Option = None; + tracing::debug!( "Extracted request body: {} bytes for POST /v1/responses", body_bytes.len() @@ -1087,11 +1102,175 @@ async fn proxy_responses( if let Ok(body_str) = std::str::from_utf8(&body_bytes) { tracing::debug!("Request body content: {}", body_str); } + + // Try to parse JSON body once for further processing (model visibility + system prompt) + match serde_json::from_slice::(&body_bytes) { + Ok(v) => { + body_json = Some(v); + } + Err(e) => { + tracing::debug!( + "Failed to parse /responses request body as JSON for user_id={}: {}", + user.user_id, + e + ); + } + } + } + + // Enforce model-level visibility based on settings if a model is specified + if let Some(ref body) = body_json { + if let Some(model_id) = body.get("model").and_then(|v| v.as_str()) { + model_id_from_body = Some(model_id.to_string()); + + // 1) Try cache first + { + let cache = state.model_settings_cache.read().await; + if let Some(entry) = cache.get(model_id) { + let age = Utc::now().signed_duration_since(entry.last_checked_at); + if age.num_seconds() >= 0 && age.num_seconds() < MODEL_SETTINGS_CACHE_TTL_SECS { + model_settings_cache_hit = Some(true); + + if !entry.public { + tracing::warn!( + "Blocking response request for non-public model '{}' from user {} (cache)", + model_id, + user.user_id + ); + return Err(( + StatusCode::FORBIDDEN, + Json(ErrorResponse { + error: "This model is not available".to_string(), + }), + ) + .into_response()); + } + + model_system_prompt = entry.system_prompt.clone(); + } + } + } + + // 2) Cache miss or expired: fetch from DB/service and populate cache + if model_settings_cache_hit.is_none() { + model_settings_cache_hit = Some(false); + match state.model_service.get_model(model_id).await { + Ok(Some(model)) => { + // Populate cache + { + let mut cache = state.model_settings_cache.write().await; + cache.insert( + model_id.to_string(), + crate::state::ModelSettingsCacheEntry { + last_checked_at: Utc::now(), + exists: true, + public: model.settings.public, + system_prompt: model.settings.system_prompt.clone(), + }, + ); + } + + if !model.settings.public { + tracing::warn!( + "Blocking response request for non-public model '{}' from user {}", + model_id, + user.user_id + ); + return Err(( + StatusCode::FORBIDDEN, + Json(ErrorResponse { + error: "This model is not available".to_string(), + }), + ) + .into_response()); + } + + model_system_prompt = model.settings.system_prompt.clone(); + } + Ok(None) => { + // Model not in admin DB - allow by default per MODEL_PUBLIC_DEFAULT + // Cache with defaults to avoid repeated DB hits, let OpenAI validate model existence + { + let mut cache = state.model_settings_cache.write().await; + cache.insert( + model_id.to_string(), + crate::state::ModelSettingsCacheEntry { + last_checked_at: Utc::now(), + exists: false, // Not in DB but allowed with defaults + public: MODEL_PUBLIC_DEFAULT, // true by default + system_prompt: None, + }, + ); + } + + // Continue with defaults - let OpenAI validate model existence + model_system_prompt = None; + } + Err(_) => { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to get model".to_string(), + }), + ) + .into_response()) + } + } + } + } + } + + // If we have a model-level system prompt, inject or prepend it into the request as `instructions`. + let modified_body_bytes = + if let (Some(mut body), Some(system_prompt)) = (body_json, model_system_prompt.clone()) { + // If client already provided instructions, prepend model system prompt with two newlines. + let new_instructions = match body.get("instructions").and_then(|v| v.as_str()) { + Some(existing) if !existing.is_empty() => { + format!("{system_prompt}\n\n{existing}") + } + _ => system_prompt, + }; + + body["instructions"] = serde_json::Value::String(new_instructions); + + match serde_json::to_vec(&body) { + Ok(serialized) => Bytes::from(serialized), + Err(_) => { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to modify instructions".to_string(), + }), + ) + .into_response()) + } + } + } else { + body_bytes + }; + + // Set content-length header + match HeaderValue::from_str(&modified_body_bytes.len().to_string()) { + Ok(header_value) => { + headers.insert("content-length", header_value); + } + Err(e) => { + tracing::error!("Failed to create content-length header value: {}", e); + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to set content-length header".to_string(), + }), + ) + .into_response()); + } } // Track conversation from the request tracing::debug!("POST to /responses detected, attempting to track conversation"); - if let Err(e) = track_conversation_from_request(&state, user.user_id, &body_bytes).await { + if let Err(e) = + track_conversation_from_request(&state, user.user_id, &modified_body_bytes).await + { tracing::error!( "Failed to track conversation for user {} from /responses: {}", user.user_id, @@ -1108,7 +1287,12 @@ async fn proxy_responses( // Forward the request to OpenAI let proxy_response = state .proxy_service - .forward_request(Method::POST, "responses", headers.clone(), Some(body_bytes)) + .forward_request( + Method::POST, + "responses", + headers.clone(), + Some(modified_body_bytes), + ) .await .map_err(|e| { tracing::error!( @@ -1131,6 +1315,41 @@ async fn proxy_responses( user.user_id ); + // Access model system prompt cache during proxy response handling (for observability/debugging). + // We DO NOT expose the prompt itself; we only attach a stable hash + cache hit indicator. + let mut proxy_headers = proxy_response.headers.clone(); + if let Some(ref model_id) = model_id_from_body { + let cached_prompt_opt = { + let cache = state.model_settings_cache.read().await; + cache.get(model_id).and_then(|e| { + if e.exists { + e.system_prompt.clone() + } else { + None + } + }) + }; + + if let Some(prompt) = cached_prompt_opt { + let mut hasher = Sha256::new(); + hasher.update(prompt.as_bytes()); + let prompt_hash = format!("{:x}", hasher.finalize()); + + let _ = proxy_headers.insert( + HeaderName::from_static("x-nearai-model-system-prompt-sha256"), + HeaderValue::from_str(&prompt_hash) + .unwrap_or_else(|_| HeaderValue::from_static("")), + ); + } + + if let Some(hit) = model_settings_cache_hit { + let _ = proxy_headers.insert( + HeaderName::from_static("x-nearai-model-settings-cache"), + HeaderValue::from_static(if hit { "hit" } else { "miss" }), + ); + } + } + // Record metrics for successful responses if (200..300).contains(&proxy_response.status) { state @@ -1144,7 +1363,20 @@ async fn proxy_responses( user_id: user.user_id, activity_type: ActivityType::Response, auth_method: None, - metadata: None, + metadata: model_id_from_body.as_ref().map(|model_id| { + let mut meta = serde_json::Map::new(); + meta.insert( + "model_id".to_string(), + serde_json::Value::String(model_id.clone()), + ); + if let Some(hit) = model_settings_cache_hit { + meta.insert( + "model_settings_cache_hit".to_string(), + serde_json::Value::Bool(hit), + ); + } + serde_json::Value::Object(meta) + }), }) .await { @@ -1154,7 +1386,7 @@ async fn proxy_responses( build_response( proxy_response.status, - proxy_response.headers, + proxy_headers, Body::from_stream(proxy_response.body), ) .await @@ -1420,12 +1652,142 @@ async fn proxy_model_list( user.user_id ); - build_response( - proxy_response.status, - proxy_response.headers, - Body::from_stream(proxy_response.body), - ) - .await + // If upstream returned non-success, just proxy as-is + if !(200..300).contains(&proxy_response.status) { + return build_response( + proxy_response.status, + proxy_response.headers, + Body::from_stream(proxy_response.body), + ) + .await; + } + + // Buffer body into bytes + let body_bytes: Bytes = proxy_response + .body + .try_collect::>() + .await + .map_err(|e| { + tracing::error!( + "Failed to read model list response body for user_id={}: {}", + user.user_id, + e + ); + ( + StatusCode::BAD_GATEWAY, + Json(ErrorResponse { + error: format!("Failed to read response body: {e}"), + }), + ) + .into_response() + })? + .into_iter() + .flatten() + .collect(); + + // Try to parse JSON + let mut body_json: serde_json::Value = match serde_json::from_slice(&body_bytes) { + Ok(v) => v, + Err(e) => { + tracing::warn!( + "Failed to parse model list JSON for user_id={}: {}, returning original body", + user.user_id, + e + ); + return build_response( + proxy_response.status, + proxy_response.headers, + Body::from(body_bytes), + ) + .await; + } + }; + + let models_opt = body_json.get_mut("models").and_then(|v| v.as_array_mut()); + + let Some(models_array) = models_opt else { + tracing::debug!("No 'models' array found in model list response, returning original body"); + return Response::builder() + .status(StatusCode::from_u16(proxy_response.status).unwrap_or(StatusCode::OK)) + .header("content-type", "application/json") + .body(Body::from( + serde_json::to_vec(&body_json).unwrap_or_else(|_| body_bytes.to_vec()), + )) + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: format!("Failed to build response: {e}"), + }), + ) + .into_response() + }); + }; + + // Collect all model IDs for batch settings lookup + let mut model_ids: Vec = Vec::new(); + for model in models_array.iter() { + if let Some(model_id) = model.get("modelId").and_then(|v| v.as_str()) { + model_ids.push(model_id.to_string()); + } + } + + // Batch fetch settings for all models from the admin models table + let settings_map = state + .model_service + .get_models_by_ids(&model_ids.iter().map(|s| s.as_str()).collect::>()) + .await + .unwrap_or_else(|e| { + tracing::warn!( + "Failed to batch load model settings for model list: {}, defaulting all public={}", + e, + MODEL_PUBLIC_DEFAULT + ); + std::collections::HashMap::new() + }); + + // Attach `public` flag to each model based on its stored settings + let mut decorated_models = Vec::new(); + for mut model in std::mem::take(models_array) { + let public_flag = model + .get("modelId") + .and_then(|v| v.as_str()) + .and_then(|id| settings_map.get(id).map(|m| m.settings.public)) + .unwrap_or(MODEL_PUBLIC_DEFAULT); + + if let Some(obj) = model.as_object_mut() { + obj.insert("public".to_string(), serde_json::Value::Bool(public_flag)); + } + + decorated_models.push(model); + } + + *body_json + .get_mut("models") + .expect("Models key must exist after previous check") = + serde_json::Value::Array(decorated_models); + + let filtered_bytes = serde_json::to_vec(&body_json).unwrap_or_else(|e| { + tracing::error!( + "Failed to serialize filtered model list JSON: {}, returning original body", + e + ); + body_bytes.to_vec() + }); + + Response::builder() + .status(StatusCode::from_u16(proxy_response.status).unwrap_or(StatusCode::OK)) + .header("content-type", "application/json") + .body(Body::from(filtered_bytes)) + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: format!("Failed to build response: {e}"), + }), + ) + .into_response() + }) } async fn proxy_signature( diff --git a/crates/api/src/routes/configs.rs b/crates/api/src/routes/configs.rs new file mode 100644 index 00000000..52a62bc4 --- /dev/null +++ b/crates/api/src/routes/configs.rs @@ -0,0 +1,38 @@ +use crate::{error::ApiError, models::*, state::AppState}; +use axum::{extract::State, routing::get, Json, Router}; + +/// Get system configs (requires user authentication, not admin) +#[utoipa::path( + get, + path = "/v1/configs", + tag = "Configs", + responses( + (status = 200, description = "System configs retrieved", body = Option), + (status = 401, description = "Unauthorized", body = crate::error::ApiErrorResponse), + (status = 500, description = "Internal server error", body = crate::error::ApiErrorResponse) + ), + security( + ("session_token" = []) + ) +)] +pub async fn get_system_configs( + State(app_state): State, +) -> Result>, ApiError> { + tracing::info!("Getting system configs"); + + let config = app_state + .system_configs_service + .get_configs() + .await + .map_err(|e| { + tracing::error!(error = ?e, "Failed to get system configs"); + ApiError::internal_server_error("Failed to get system configs") + })?; + + Ok(Json(config.map(Into::into))) +} + +/// Create configs router with all configs routes (requires user authentication) +pub fn create_configs_router() -> Router { + Router::new().route("/v1/configs", get(get_system_configs)) +} diff --git a/crates/api/src/routes/mod.rs b/crates/api/src/routes/mod.rs index 3d6ae0df..37a0f765 100644 --- a/crates/api/src/routes/mod.rs +++ b/crates/api/src/routes/mod.rs @@ -1,6 +1,7 @@ pub mod admin; pub mod api; pub mod attestation; +pub mod configs; pub mod oauth; pub mod users; @@ -114,6 +115,12 @@ pub fn create_router_with_cors(app_state: AppState, cors_config: config::CorsCon let rate_limit_state = RateLimitState::new(); + // Configs routes (requires user authentication, not admin) + let configs_routes = configs::create_configs_router().layer(from_fn_with_state( + auth_state.clone(), + crate::middleware::auth_middleware, + )); + // API proxy routes (requires authentication) let api_routes = api::create_api_router(rate_limit_state).layer(from_fn_with_state( auth_state, @@ -123,6 +130,7 @@ pub fn create_router_with_cors(app_state: AppState, cors_config: config::CorsCon // Build the base router let router = Router::new() .route("/health", get(health_check)) + .merge(configs_routes) // Configs route (requires user auth) .nest("/v1/auth", auth_routes) .nest("/v1/auth", logout_route) // Logout route with auth middleware .nest("/v1/users", user_routes) diff --git a/crates/api/src/state.rs b/crates/api/src/state.rs index 8d6610fe..c0b013d8 100644 --- a/crates/api/src/state.rs +++ b/crates/api/src/state.rs @@ -14,12 +14,29 @@ pub struct NearBalanceCacheEntry { /// Type alias for NEAR balance cache (per-account) pub type NearBalanceCache = Arc>>; +/// Cached model settings entry for a given model_id +#[derive(Debug, Clone)] +pub struct ModelSettingsCacheEntry { + pub last_checked_at: DateTime, + /// Whether the model exists in the admin models table + pub exists: bool, + /// Whether this model is public (visible/usable in responses) + pub public: bool, + /// Optional system-level system prompt for this model + pub system_prompt: Option, +} + +/// Type alias for model settings cache (per-model) +pub type ModelSettingsCache = Arc>>; + /// Application state shared across all handlers #[derive(Clone)] pub struct AppState { pub oauth_service: Arc, pub user_service: Arc, pub user_settings_service: Arc, + pub model_service: Arc, + pub system_configs_service: Arc, pub session_repository: Arc, pub user_repository: Arc, pub proxy_service: Arc, @@ -38,4 +55,6 @@ pub struct AppState { pub near_rpc_url: Url, /// In-memory cache for NEAR account balances to avoid frequent RPC calls pub near_balance_cache: NearBalanceCache, + /// In-memory cache for model settings needed by /v1/responses (public + system_prompt) + pub model_settings_cache: ModelSettingsCache, } diff --git a/crates/api/tests/admin_tests.rs b/crates/api/tests/admin_tests.rs index a3784c72..e07d12f8 100644 --- a/crates/api/tests/admin_tests.rs +++ b/crates/api/tests/admin_tests.rs @@ -1,9 +1,6 @@ mod common; -use common::{create_test_server, create_test_server_with_config, mock_login, TestServerConfig}; -use services::vpc::VpcCredentials; -use wiremock::matchers::{method, path}; -use wiremock::{Mock, MockServer, ResponseTemplate}; +use common::{create_test_server, mock_login}; #[tokio::test] async fn test_admin_users_list_with_admin_account() { @@ -83,218 +80,3 @@ async fn test_admin_users_list_pagination() { assert!(users.len() <= 2, "Should return at most 2 users"); assert!(total >= 4, "Should have at least 4 users total"); } - -#[tokio::test] -async fn test_get_system_prompt_with_non_admin_account() { - let server = create_test_server().await; - - let non_admin_email = "test_user_prompt@no-admin.org"; - let non_admin_token = mock_login(&server, non_admin_email).await; - - let response = server - .get("/v1/admin/system_prompt") - .add_header( - http::HeaderName::from_static("authorization"), - http::HeaderValue::from_str(&format!("Bearer {non_admin_token}")).unwrap(), - ) - .await; - - let status = response.status_code(); - assert_eq!( - status, 403, - "Non-admin should receive 403 Forbidden when trying to get system prompt" - ); - - let body: serde_json::Value = response.json(); - let error = body.get("message").and_then(|v| v.as_str()); - assert_eq!(error, Some("Admin access required")); -} - -#[tokio::test] -async fn test_set_system_prompt_with_non_admin_account() { - let server = create_test_server().await; - - let non_admin_email = "test_user_set_prompt@no-admin.org"; - let non_admin_token = mock_login(&server, non_admin_email).await; - - let response = server - .post("/v1/admin/system_prompt") - .add_header( - http::HeaderName::from_static("authorization"), - http::HeaderValue::from_str(&format!("Bearer {non_admin_token}")).unwrap(), - ) - .add_header( - http::HeaderName::from_static("content-type"), - http::HeaderValue::from_static("application/json"), - ) - .json(&serde_json::json!({ - "system_prompt": "You are a helpful assistant." - })) - .await; - - let status = response.status_code(); - assert_eq!( - status, 403, - "Non-admin should receive 403 Forbidden when trying to set system prompt" - ); - - let body: serde_json::Value = response.json(); - let error = body.get("message").and_then(|v| v.as_str()); - assert_eq!(error, Some("Admin access required")); -} - -#[tokio::test] -async fn test_get_system_prompt_with_admin_success() { - let mock_cloud_api = MockServer::start().await; - - Mock::given(method("GET")) - .and(path("/organizations/test-org-123/settings")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "settings": { - "system_prompt": "You are a helpful AI assistant." - } - }))) - .mount(&mock_cloud_api) - .await; - - let vpc_credentials = VpcCredentials { - access_token: "test-access-token".to_string(), - organization_id: "test-org-123".to_string(), - api_key: "test-api-key".to_string(), - }; - - let server = create_test_server_with_config(TestServerConfig { - vpc_credentials: Some(vpc_credentials), - cloud_api_base_url: mock_cloud_api.uri(), - }) - .await; - - let admin_email = "test_admin_get_prompt_success@admin.org"; - let admin_token = mock_login(&server, admin_email).await; - - let response = server - .get("/v1/admin/system_prompt") - .add_header( - http::HeaderName::from_static("authorization"), - http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), - ) - .await; - - let status = response.status_code(); - assert_eq!(status, 200, "Admin should be able to get system prompt"); - - let body: serde_json::Value = response.json(); - let system_prompt = body.get("system_prompt").and_then(|v| v.as_str()); - assert_eq!(system_prompt, Some("You are a helpful AI assistant.")); -} - -#[tokio::test] -async fn test_set_system_prompt_with_admin_success() { - let mock_cloud_api = MockServer::start().await; - - Mock::given(method("PATCH")) - .and(path("/organizations/test-org-456/settings")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "settings": { - "system_prompt": "You are a helpful assistant." - } - }))) - .mount(&mock_cloud_api) - .await; - - let vpc_credentials = VpcCredentials { - access_token: "test-access-token".to_string(), - organization_id: "test-org-456".to_string(), - api_key: "test-api-key".to_string(), - }; - - let server = create_test_server_with_config(TestServerConfig { - vpc_credentials: Some(vpc_credentials), - cloud_api_base_url: mock_cloud_api.uri(), - }) - .await; - - let admin_email = "test_admin_set_prompt_success@admin.org"; - let admin_token = mock_login(&server, admin_email).await; - - let response = server - .post("/v1/admin/system_prompt") - .add_header( - http::HeaderName::from_static("authorization"), - http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), - ) - .add_header( - http::HeaderName::from_static("content-type"), - http::HeaderValue::from_static("application/json"), - ) - .json(&serde_json::json!({ - "system_prompt": "You are a helpful assistant." - })) - .await; - - let status = response.status_code(); - assert_eq!(status, 200, "Admin should be able to set system prompt"); - - let body: serde_json::Value = response.json(); - let system_prompt = body.get("system_prompt").and_then(|v| v.as_str()); - assert_eq!(system_prompt, Some("You are a helpful assistant.")); -} - -#[tokio::test] -async fn test_get_system_prompt_with_admin_but_vpc_not_configured() { - let server = create_test_server().await; - - let admin_email = "test_admin_get_prompt@admin.org"; - let admin_token = mock_login(&server, admin_email).await; - - let response = server - .get("/v1/admin/system_prompt") - .add_header( - http::HeaderName::from_static("authorization"), - http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), - ) - .await; - - let status = response.status_code(); - assert_eq!( - status, 500, - "Admin should receive 500 Internal Server Error when VPC is not configured" - ); - - let body: serde_json::Value = response.json(); - let error = body.get("message").and_then(|v| v.as_str()); - assert_eq!(error, Some("Cloud API not configured")); -} - -#[tokio::test] -async fn test_set_system_prompt_with_admin_but_vpc_not_configured() { - let server = create_test_server().await; - - let admin_email = "test_admin_set_prompt@admin.org"; - let admin_token = mock_login(&server, admin_email).await; - - let response = server - .post("/v1/admin/system_prompt") - .add_header( - http::HeaderName::from_static("authorization"), - http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), - ) - .add_header( - http::HeaderName::from_static("content-type"), - http::HeaderValue::from_static("application/json"), - ) - .json(&serde_json::json!({ - "system_prompt": "You are a helpful assistant." - })) - .await; - - let status = response.status_code(); - assert_eq!( - status, 500, - "Admin should receive 500 Internal Server Error when VPC is not configured" - ); - - let body: serde_json::Value = response.json(); - let error = body.get("message").and_then(|v| v.as_str()); - assert_eq!(error, Some("Cloud API not configured")); -} diff --git a/crates/api/tests/common.rs b/crates/api/tests/common.rs index 0fecacfc..88c126f4 100644 --- a/crates/api/tests/common.rs +++ b/crates/api/tests/common.rs @@ -55,6 +55,8 @@ pub async fn create_test_server_with_config(test_config: TestServerConfig) -> Te let conversation_repo = db.conversation_repository(); let file_repo = db.file_repository(); let user_settings_repo = db.user_settings_repository(); + let model_repo = db.model_repository(); + let system_configs_repo = db.system_configs_repository(); let near_nonce_repo = db.near_nonce_repository(); // Create services @@ -74,9 +76,18 @@ pub async fn create_test_server_with_config(test_config: TestServerConfig) -> Te let user_service = Arc::new(services::user::UserServiceImpl::new(user_repo.clone())); let user_settings_service = Arc::new(services::user::UserSettingsServiceImpl::new( - user_settings_repo as Arc, + user_settings_repo, )); + let model_service = Arc::new(services::model::service::ModelServiceImpl::new(model_repo)); + + let system_configs_service = Arc::new( + services::system_configs::service::SystemConfigsServiceImpl::new( + system_configs_repo + as Arc, + ), + ); + // Create VPC credentials service based on provided credentials let vpc_credentials_service: Arc = match test_config.vpc_credentials { @@ -119,11 +130,13 @@ pub async fn create_test_server_with_config(test_config: TestServerConfig) -> Te // Create application state let app_state = AppState { - vpc_credentials_service, oauth_service, user_service, user_settings_service, + model_service, + system_configs_service, session_repository: session_repo, + vpc_credentials_service, user_repository: user_repo, proxy_service, conversation_service, @@ -135,6 +148,7 @@ pub async fn create_test_server_with_config(test_config: TestServerConfig) -> Te analytics_service, near_rpc_url: config.near.rpc_url.clone(), near_balance_cache: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())), + model_settings_cache: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())), }; // Create router diff --git a/crates/api/tests/models_tests.rs b/crates/api/tests/models_tests.rs new file mode 100644 index 00000000..30e97e58 --- /dev/null +++ b/crates/api/tests/models_tests.rs @@ -0,0 +1,610 @@ +mod common; + +use common::{create_test_server, mock_login}; +use serde_json::json; + +#[tokio::test] +async fn test_list_models_response_structure() { + let server = create_test_server().await; + + // Use an admin account to access admin endpoints + let admin_email = "test_admin_model_default@admin.org"; + let admin_token = mock_login(&server, admin_email).await; + + let response = server + .get("/v1/admin/models?limit=10&offset=0") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .await; + + let status = response.status_code(); + assert_eq!(status, 200, "Should return 200 when listing models"); + + let body: serde_json::Value = response.json(); + // Verify response structure (don't assert exact count as other tests may have created models) + assert!( + body.get("total").is_some(), + "Response should have total field" + ); + assert!( + body.get("models").is_some(), + "Response should have models array" + ); + assert!( + body.get("limit").is_some(), + "Response should have limit field" + ); + assert!( + body.get("offset").is_some(), + "Response should have offset field" + ); + + let models = body + .get("models") + .and_then(|v| v.as_array()) + .expect("Should have models array"); + let total: i64 = body + .get("total") + .and_then(|v| v.as_i64()) + .expect("Should have total as number"); + + // Verify pagination structure is correct + assert!( + models.len() as i64 <= total, + "Models array length should not exceed total" + ); + assert!( + models.len() as i64 <= 10, // limit is 10 + "Models array length should not exceed limit" + ); +} + +#[tokio::test] +async fn test_model_batch_update_and_list() { + let server = create_test_server().await; + + let admin_email = "test_admin_model_update@admin.org"; + let admin_token = mock_login(&server, admin_email).await; + + // Use a unique model ID to avoid conflicts with other tests + let test_model_id = "test-model-batch-update-and-list"; + + // Batch upsert model with settings.public = true + let batch_body = json!({ + test_model_id: { + "public": true + } + }); + + let response = server + .patch("/v1/admin/models") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&batch_body) + .await; + + let status = response.status_code(); + assert_eq!(status, 200, "Admin should be able to batch upsert model"); + + let body: serde_json::Value = response.json(); + assert!(body.is_array(), "Response should be an array of models"); + let models = body.as_array().expect("Should be an array"); + assert_eq!(models.len(), 1, "Should return one model"); + let model = &models[0]; + assert_eq!( + model.get("model_id"), + Some(&json!(test_model_id)), + "Response should include correct model_id" + ); + let settings = model.get("settings").expect("Should have settings field"); + assert_eq!( + settings.get("public"), + Some(&json!(true)), + "Public should be true after update" + ); + + // List models to verify persisted settings + let response = server + .get("/v1/admin/models?limit=100&offset=0") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .await; + + let status = response.status_code(); + assert_eq!(status, 200, "Admin should be able to list models"); + + let body: serde_json::Value = response.json(); + let models = body + .get("models") + .and_then(|v| v.as_array()) + .expect("Should have models array"); + + // Find our model in the list (there may be other models from other tests) + let our_model = models + .iter() + .find(|m| m.get("model_id") == Some(&json!(test_model_id))) + .expect("Should find our model in the list"); + + assert_eq!( + our_model.get("model_id"), + Some(&json!(test_model_id)), + "Listed model should have correct model_id" + ); + let settings = our_model + .get("settings") + .expect("Should have settings field"); + assert_eq!( + settings.get("public"), + Some(&json!(true)), + "Public should remain true when listed" + ); +} + +#[tokio::test] +async fn test_list_models_requires_admin() { + let server = create_test_server().await; + + let non_admin_email = "test_user_model@no-admin.org"; + let non_admin_token = mock_login(&server, non_admin_email).await; + + let response = server + .get("/v1/admin/models?limit=10&offset=0") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {non_admin_token}")).unwrap(), + ) + .await; + + let status = response.status_code(); + assert_eq!( + status, 403, + "Non-admin should receive 403 Forbidden when accessing model admin API" + ); + + let body: serde_json::Value = response.json(); + let error = body.get("message").and_then(|v| v.as_str()); + assert_eq!(error, Some("Admin access required")); +} + +#[tokio::test] +async fn test_delete_model_success() { + let server = create_test_server().await; + + let admin_email = "test_admin_model_delete_success@admin.org"; + let admin_token = mock_login(&server, admin_email).await; + + // First, create a model + let batch_body = json!({ + "test-model-delete-1": { + "public": true + } + }); + + let response = server + .patch("/v1/admin/models") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&batch_body) + .await; + + assert!( + response.status_code().is_success(), + "Admin should be able to batch upsert model before deleting" + ); + + // Then, delete the model + let response = server + .delete("/v1/admin/models/test-model-delete-1") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .await; + + assert_eq!( + response.status_code(), + 204, + "Deleting existing model should return 204 No Content" + ); + + // Verify the model is gone by listing models + let response = server + .get("/v1/admin/models?limit=10&offset=0") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .await; + + assert_eq!( + response.status_code(), + 200, + "Listing models after deletion should return 200" + ); + + let body: serde_json::Value = response.json(); + let models = body + .get("models") + .and_then(|v| v.as_array()) + .expect("Should have models array"); + let model_ids: Vec<&str> = models + .iter() + .filter_map(|m| m.get("model_id").and_then(|v| v.as_str())) + .collect(); + assert!( + !model_ids.contains(&"test-model-delete-1"), + "Deleted model should not appear in the list, got: {model_ids:?}" + ); +} + +#[tokio::test] +async fn test_delete_model_not_found() { + let server = create_test_server().await; + + let admin_email = "test_admin_model_delete_not_found@admin.org"; + let admin_token = mock_login(&server, admin_email).await; + + let response = server + .delete("/v1/admin/models/non-existent-model") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .await; + + assert_eq!( + response.status_code(), + 404, + "Deleting a non-existent model should return 404" + ); + + let body: serde_json::Value = response.json(); + assert_eq!( + body.get("message"), + Some(&json!("Model not found")), + "Error message should indicate model not found" + ); +} + +#[tokio::test] +async fn test_delete_model_requires_admin() { + let server = create_test_server().await; + + let non_admin_email = "test_user_model_delete@no-admin.org"; + let non_admin_token = mock_login(&server, non_admin_email).await; + + let response = server + .delete("/v1/admin/models/test-model-delete-requires-admin") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {non_admin_token}")).unwrap(), + ) + .await; + + assert_eq!( + response.status_code(), + 403, + "Non-admin should receive 403 Forbidden when deleting a model" + ); + + let body: serde_json::Value = response.json(); + let error = body.get("message").and_then(|v| v.as_str()); + assert_eq!(error, Some("Admin access required")); +} + +/// Requests with a model whose settings are non-public should be blocked with 403. +#[tokio::test] +async fn test_responses_block_non_public_model() { + let server = create_test_server().await; + + // Use an admin account to configure model settings + let admin_email = "visibility-non-public-admin@admin.org"; + let admin_token = mock_login(&server, admin_email).await; + + // Explicitly create a non-public model via admin API + let batch_body = json!({ + "test-non-public-model": { + "public": false + } + }); + + let response = server + .patch("/v1/admin/models") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&batch_body) + .await; + + assert!( + response.status_code().is_success(), + "Admin should be able to set model as non-public" + ); + + // Now send a response request using the non-public model + let body = json!({ + "model": "test-non-public-model", + "input": "Hello" + }); + + let response = server + .post("/v1/responses") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&body) + .await; + + assert_eq!( + response.status_code(), + 403, + "Requests with non-public model should be blocked with 403" + ); + + let body: serde_json::Value = response.json(); + assert_eq!( + body.get("error").and_then(|v| v.as_str()), + Some("This model is not available"), + "Error message should indicate model is not available" + ); +} + +/// Requests with a model whose settings are public should be allowed (not blocked by 403). +#[tokio::test] +async fn test_responses_allow_public_model() { + let server = create_test_server().await; + + // Use an admin account to configure model settings + let admin_email = "visibility-public-admin@admin.org"; + let admin_token = mock_login(&server, admin_email).await; + + // Mark the model as public via admin API + let batch_body = json!({ + "test-public-model": { + "public": true + } + }); + + let response = server + .patch("/v1/admin/models") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&batch_body) + .await; + + assert!( + response.status_code().is_success(), + "Admin should be able to set model as public" + ); + + // Now send a response request using the public model + let body = json!({ + "model": "test-public-model", + "input": "Hello" + }); + + let response = server + .post("/v1/responses") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&body) + .await; + + // We only assert that our visibility check did not block the request with 403. + assert_ne!( + response.status_code(), + 403, + "Requests with public model should not be blocked by visibility check (status was {})", + response.status_code() + ); +} + +/// When a public model has a system_prompt and the client does NOT send instructions, +/// the proxy should inject `instructions = system_prompt` into the forwarded request. +#[tokio::test] +async fn test_responses_injects_system_prompt_when_instructions_missing() { + let server = create_test_server().await; + + // Use an admin account to configure model settings (public + system_prompt) + let admin_email = "system-prompt-no-instructions-admin@admin.org"; + let admin_token = mock_login(&server, admin_email).await; + + let system_prompt = "You are a helpful assistant (model-level)."; + + let batch_body = json!({ + "test-system-prompt-model-1": { + "public": true, + "system_prompt": system_prompt + } + }); + + let response = server + .patch("/v1/admin/models") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&batch_body) + .await; + + assert!( + response.status_code().is_success(), + "Admin should be able to set model with system_prompt" + ); + + // Now send a responses request WITHOUT instructions + let user_email = "system-prompt-no-instructions-user@example.com"; + let user_token = mock_login(&server, user_email).await; + + let body = json!({ + "model": "test-system-prompt-model-1", + "input": "Hello" + }); + + let response = server + .post("/v1/responses") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {user_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&body) + .await; + + // We cannot directly inspect the upstream OpenAI request in this test, + // but at minimum the proxy should not block, since the model is public. + assert_ne!( + response.status_code(), + 403, + "Requests with public model and system_prompt should not be blocked (status was {})", + response.status_code() + ); +} + +/// When a public model has a system_prompt and the client already sends instructions, +/// the proxy should prepend the model system_prompt with two newlines. +#[tokio::test] +async fn test_responses_prepends_system_prompt_when_instructions_present() { + let server = create_test_server().await; + + // Use an admin account to configure model settings (public + system_prompt) + let admin_email = "system-prompt-with-instructions-admin@admin.org"; + let admin_token = mock_login(&server, admin_email).await; + + let system_prompt = "You are a helpful assistant (model-level, prepend)."; + + let batch_body = json!({ + "test-system-prompt-model-2": { + "public": true, + "system_prompt": system_prompt + } + }); + + let response = server + .patch("/v1/admin/models") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&batch_body) + .await; + + assert!( + response.status_code().is_success(), + "Admin should be able to set model with system_prompt" + ); + + // Now send a responses request WITH client instructions + let user_email = "system-prompt-with-instructions-user@example.com"; + let user_token = mock_login(&server, user_email).await; + + let body = json!({ + "model": "test-system-prompt-model-2", + "instructions": "User provided instructions.", + "input": "Hello" + }); + + let response = server + .post("/v1/responses") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {user_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&body) + .await; + + // Similarly, we cannot directly assert the final contents of `instructions` here, + // but we can at least ensure that the request is not blocked by the visibility logic. + assert_ne!( + response.status_code(), + 403, + "Requests with public model and custom instructions should not be blocked (status was {})", + response.status_code() + ); +} + +/// Requests without a `model` field should be allowed (no 403 from visibility logic). +#[tokio::test] +async fn test_responses_allow_without_model_field() { + let server = create_test_server().await; + + let token = mock_login(&server, "visibility-no-model@example.com").await; + + // No `model` field in body + let body = json!({ + "input": "Hello" + }); + + let response = server + .post("/v1/responses") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&body) + .await; + + assert_ne!( + response.status_code(), + 403, + "Requests without model field should not be blocked by visibility check (status was {})", + response.status_code() + ); +} diff --git a/crates/api/tests/system_configs_tests.rs b/crates/api/tests/system_configs_tests.rs new file mode 100644 index 00000000..92a3ef92 --- /dev/null +++ b/crates/api/tests/system_configs_tests.rs @@ -0,0 +1,223 @@ +mod common; + +use common::{create_test_server, mock_login}; +use serde_json::json; + +#[tokio::test] +async fn test_upsert_system_configs_and_get() { + let server = create_test_server().await; + + let admin_email = "test_admin_configs_upsert@admin.org"; + let admin_token = mock_login(&server, admin_email).await; + + // Upsert system configs with a default_model value + let upsert_body = json!({ + "default_model": "test-default-model-1" + }); + + let response = server + .patch("/v1/admin/configs") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&upsert_body) + .await; + + assert!( + response.status_code().is_success(), + "Admin should be able to upsert system configs" + ); + + let body: serde_json::Value = response.json(); + assert_eq!( + body.get("default_model"), + Some(&json!("test-default-model-1")), + "Upserted configs should contain correct default_model" + ); + + // Get system configs to verify it was persisted (requires user auth) + let response = server + .get("/v1/configs") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .await; + + assert_eq!( + response.status_code(), + 200, + "Admin should be able to get system configs" + ); + + let body: serde_json::Value = response.json(); + assert!( + body.is_object(), + "System configs GET after upsert should return an object, got: {body:?}" + ); + assert_eq!( + body.get("default_model"), + Some(&json!("test-default-model-1")), + "Fetched system configs should contain correct default_model" + ); +} + +#[tokio::test] +async fn test_update_system_configs() { + let server = create_test_server().await; + + let admin_email = "test_admin_configs_update@admin.org"; + let admin_token = mock_login(&server, admin_email).await; + + // Ensure a known initial config via upsert + let upsert_body = json!({ + "default_model": "initial-model" + }); + + let response = server + .patch("/v1/admin/configs") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&upsert_body) + .await; + + assert!( + response.status_code().is_success(), + "Admin should be able to upsert initial system configs" + ); + + // Partially update configs + let update_body = json!({ + "default_model": "updated-model" + }); + + let response = server + .patch("/v1/admin/configs") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&update_body) + .await; + + assert!( + response.status_code().is_success(), + "Admin should be able to update system configs" + ); + + let body: serde_json::Value = response.json(); + assert_eq!( + body.get("default_model"), + Some(&json!("updated-model")), + "Updated system configs should contain new default_model" + ); + + // Verify via GET (requires user auth) + let response = server + .get("/v1/configs") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {admin_token}")).unwrap(), + ) + .await; + + assert_eq!( + response.status_code(), + 200, + "Admin should be able to get updated system configs" + ); + + let body: serde_json::Value = response.json(); + assert_eq!( + body.get("default_model"), + Some(&json!("updated-model")), + "Fetched system configs should reflect updated default_model" + ); +} + +#[tokio::test] +async fn test_get_system_configs_requires_auth() { + let server = create_test_server().await; + + // GET /v1/configs without authentication should return 401 + let response = server.get("/v1/configs").await; + + assert_eq!( + response.status_code(), + 401, + "GET /v1/configs should require user authentication" + ); +} + +#[tokio::test] +async fn test_get_system_configs_allows_non_admin() { + let server = create_test_server().await; + + let non_admin_email = "test_user_configs@no-admin.org"; + let non_admin_token = mock_login(&server, non_admin_email).await; + + // Non-admin user should be able to GET system configs (only write requires admin) + let response = server + .get("/v1/configs") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {non_admin_token}")).unwrap(), + ) + .await; + + assert!( + response.status_code().is_success() || response.status_code() == 200, + "Non-admin users should be able to GET system configs with authentication" + ); +} + +#[tokio::test] +async fn test_system_configs_write_requires_admin() { + let server = create_test_server().await; + + let non_admin_email = "test_user_configs@no-admin.org"; + let non_admin_token = mock_login(&server, non_admin_email).await; + + // Non-admin trying to PATCH system configs should receive 403 + let upsert_body = json!({ + "default_model": "test-model" + }); + + let response = server + .patch("/v1/admin/configs") + .add_header( + http::HeaderName::from_static("authorization"), + http::HeaderValue::from_str(&format!("Bearer {non_admin_token}")).unwrap(), + ) + .add_header( + http::HeaderName::from_static("content-type"), + http::HeaderValue::from_static("application/json"), + ) + .json(&upsert_body) + .await; + + assert_eq!( + response.status_code(), + 403, + "Non-admin should receive 403 Forbidden when writing system configs" + ); + + let body: serde_json::Value = response.json(); + let error = body.get("message").and_then(|v| v.as_str()); + assert_eq!(error, Some("Admin access required")); +} diff --git a/crates/database/src/lib.rs b/crates/database/src/lib.rs index dfab883d..0bfbd81e 100644 --- a/crates/database/src/lib.rs +++ b/crates/database/src/lib.rs @@ -7,8 +7,9 @@ pub mod repositories; pub use pool::DbPool; pub use repositories::{ PostgresAnalyticsRepository, PostgresAppConfigRepository, PostgresConversationRepository, - PostgresFileRepository, PostgresNearNonceRepository, PostgresOAuthRepository, - PostgresSessionRepository, PostgresUserRepository, PostgresUserSettingsRepository, + PostgresFileRepository, PostgresModelRepository, PostgresNearNonceRepository, + PostgresOAuthRepository, PostgresSessionRepository, PostgresSystemConfigsRepository, + PostgresUserRepository, PostgresUserSettingsRepository, }; use crate::pool::create_pool_with_native_tls; @@ -28,9 +29,11 @@ pub struct Database { conversation_repository: Arc, file_repository: Arc, user_settings_repository: Arc, + system_configs_repository: Arc, app_config_repository: Arc, near_nonce_repository: Arc, analytics_repository: Arc, + model_repository: Arc, cluster_manager: Option>, } @@ -43,9 +46,12 @@ impl Database { let conversation_repository = Arc::new(PostgresConversationRepository::new(pool.clone())); let file_repository = Arc::new(PostgresFileRepository::new(pool.clone())); let user_settings_repository = Arc::new(PostgresUserSettingsRepository::new(pool.clone())); + let system_configs_repository = + Arc::new(PostgresSystemConfigsRepository::new(pool.clone())); let app_config_repository = Arc::new(PostgresAppConfigRepository::new(pool.clone())); let near_nonce_repository = Arc::new(PostgresNearNonceRepository::new(pool.clone())); let analytics_repository = Arc::new(PostgresAnalyticsRepository::new(pool.clone())); + let model_repository = Arc::new(PostgresModelRepository::new(pool.clone())); Self { pool, @@ -55,9 +61,11 @@ impl Database { conversation_repository, file_repository, user_settings_repository, + system_configs_repository, app_config_repository, near_nonce_repository, analytics_repository, + model_repository, cluster_manager: None, } } @@ -226,4 +234,14 @@ impl Database { pub fn analytics_repository(&self) -> Arc { self.analytics_repository.clone() } + + /// Get the model settings repository + pub fn model_repository(&self) -> Arc { + self.model_repository.clone() + } + + /// Get the system configs repository + pub fn system_configs_repository(&self) -> Arc { + self.system_configs_repository.clone() + } } diff --git a/crates/database/src/migrations/sql/V13__add_models.sql b/crates/database/src/migrations/sql/V13__add_models.sql new file mode 100644 index 00000000..46574938 --- /dev/null +++ b/crates/database/src/migrations/sql/V13__add_models.sql @@ -0,0 +1,16 @@ +-- Create models table +-- This table stores admin-level model settings per model as JSONB to allow flexible schema evolution +CREATE TABLE models ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + model_id TEXT NOT NULL UNIQUE, + settings JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Trigger for updating updated_at timestamp +CREATE TRIGGER update_models_updated_at + BEFORE UPDATE ON models + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + diff --git a/crates/database/src/migrations/sql/V14__add_system_config.sql b/crates/database/src/migrations/sql/V14__add_system_config.sql new file mode 100644 index 00000000..c21e8979 --- /dev/null +++ b/crates/database/src/migrations/sql/V14__add_system_config.sql @@ -0,0 +1,15 @@ +-- Table for storing application-wide JSONB configuration +CREATE TABLE system_configs ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + key TEXT NOT NULL UNIQUE, + value JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Trigger for updating updated_at timestamp +CREATE TRIGGER update_system_configs_updated_at + BEFORE UPDATE ON system_configs + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + diff --git a/crates/database/src/repositories/mod.rs b/crates/database/src/repositories/mod.rs index 581a8252..7fa55f42 100644 --- a/crates/database/src/repositories/mod.rs +++ b/crates/database/src/repositories/mod.rs @@ -2,9 +2,11 @@ pub mod analytics_repository; pub mod app_config_repository; pub mod conversation_repository; pub mod file_repository; +pub mod model_repository; pub mod near_nonce_repository; pub mod oauth_repository; pub mod session_repository; +pub mod system_configs_repository; pub mod user_repository; pub mod user_settings_repository; @@ -12,8 +14,10 @@ pub use analytics_repository::PostgresAnalyticsRepository; pub use app_config_repository::PostgresAppConfigRepository; pub use conversation_repository::PostgresConversationRepository; pub use file_repository::PostgresFileRepository; +pub use model_repository::PostgresModelRepository; pub use near_nonce_repository::PostgresNearNonceRepository; pub use oauth_repository::PostgresOAuthRepository; pub use session_repository::PostgresSessionRepository; +pub use system_configs_repository::PostgresSystemConfigsRepository; pub use user_repository::PostgresUserRepository; pub use user_settings_repository::PostgresUserSettingsRepository; diff --git a/crates/database/src/repositories/model_repository.rs b/crates/database/src/repositories/model_repository.rs new file mode 100644 index 00000000..63c4c230 --- /dev/null +++ b/crates/database/src/repositories/model_repository.rs @@ -0,0 +1,260 @@ +use crate::pool::DbPool; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use services::model::ports::{ + Model, ModelSettings, ModelsRepository, PartialModelSettings, UpdateModelParams, + UpsertModelParams, +}; +use tokio_postgres::Row; +use uuid::Uuid; + +pub struct PostgresModelRepository { + pool: DbPool, +} + +impl PostgresModelRepository { + pub fn new(pool: DbPool) -> Self { + Self { pool } + } +} + +#[async_trait] +impl ModelsRepository for PostgresModelRepository { + async fn get_model(&self, model_id: &str) -> anyhow::Result> { + tracing::debug!( + "Repository: Fetching model settings for model_id={}", + model_id + ); + + let client = self.pool.get().await?; + + let row = client + .query_opt( + "SELECT id, model_id, settings, created_at, updated_at + FROM models + WHERE model_id = $1", + &[&model_id], + ) + .await?; + + if let Some(row) = row { + let settings_json: serde_json::Value = row.get("settings"); + + let default_settings = ModelSettings::default(); + // Missing fields will be filled from default settings values + let settings_delta = serde_json::from_value::(settings_json)?; + let settings = default_settings.into_updated(settings_delta); + + Ok(Some(Model { + id: row.get("id"), + model_id: row.get("model_id"), + settings, + created_at: row.get("created_at"), + updated_at: row.get("updated_at"), + })) + } else { + Ok(None) + } + } + + async fn list_models(&self, limit: i64, offset: i64) -> anyhow::Result<(Vec, i64)> { + tracing::debug!( + "Repository: Listing models with limit={}, offset={}", + limit, + offset + ); + + let client = self.pool.get().await?; + + // Get total count + let count_row = client + .query_one("SELECT COUNT(*) as count FROM models", &[]) + .await?; + let total: i64 = count_row.get("count"); + + // Get paginated models + let rows = client + .query( + "SELECT id, model_id, settings, created_at, updated_at + FROM models + ORDER BY model_id ASC + LIMIT $1 OFFSET $2", + &[&limit, &offset], + ) + .await?; + + let mut models = Vec::new(); + for row in rows { + let settings_json: serde_json::Value = row.get("settings"); + + let default_settings = ModelSettings::default(); + let settings_delta = serde_json::from_value::(settings_json)?; + let settings = default_settings.into_updated(settings_delta); + + models.push(Model { + id: row.get("id"), + model_id: row.get("model_id"), + settings, + created_at: row.get("created_at"), + updated_at: row.get("updated_at"), + }); + } + + Ok((models, total)) + } + + async fn get_models_by_ids( + &self, + model_ids: &[&str], + ) -> anyhow::Result> { + tracing::debug!( + "Repository: Fetching models for {} model_ids", + model_ids.len() + ); + + if model_ids.is_empty() { + return Ok(std::collections::HashMap::new()); + } + + let client = self.pool.get().await?; + + let rows = client + .query( + "SELECT id, model_id, settings, created_at, updated_at + FROM models + WHERE model_id = ANY($1)", + &[&model_ids], + ) + .await?; + + let mut map = std::collections::HashMap::new(); + + for row in rows { + let id: Uuid = row.get("id"); + let model_id: String = row.get("model_id"); + let created_at: DateTime = row.get("created_at"); + let updated_at: DateTime = row.get("updated_at"); + + let settings = load_settings_from_row(&row)?; + + let model = Model { + id, + model_id: model_id.clone(), + settings, + created_at, + updated_at, + }; + + map.insert(model_id, model); + } + + Ok(map) + } + + async fn upsert_model(&self, params: UpsertModelParams) -> anyhow::Result { + tracing::info!( + "Repository: Upserting model for model_id={}", + params.model_id + ); + + let client = self.pool.get().await?; + + let settings = serde_json::to_value(params.settings.clone())?; + + // Insert or update by model_id + let row = client + .query_one( + "INSERT INTO models (model_id, settings) + VALUES ($1, $2) + ON CONFLICT (model_id) + DO UPDATE SET settings = EXCLUDED.settings, updated_at = NOW() + RETURNING *", + &[¶ms.model_id, &settings], + ) + .await?; + + let model_id = params.model_id; + + let settings = Model { + id: row.get("id"), + model_id: model_id.clone(), + settings: params.settings, + created_at: row.get("created_at"), + updated_at: row.get("updated_at"), + }; + + tracing::info!( + "Repository: Model settings upserted successfully for model_id={}", + model_id + ); + + Ok(settings) + } + + async fn update_model(&self, params: UpdateModelParams) -> anyhow::Result { + tracing::info!( + "Repository: Updating model for model_id={}", + params.model_id + ); + + let existing_model = self.get_model(¶ms.model_id).await?; + + let Some(existing_model) = existing_model else { + anyhow::bail!("Model not found for model id: {}", params.model_id); + }; + + let client = self.pool.get().await?; + + // Merge with incoming partial settings (if provided) + let new_settings = if let Some(delta) = params.settings { + existing_model.settings.into_updated(delta) + } else { + existing_model.settings + }; + + let new_settings_json = serde_json::to_value(new_settings.clone())?; + + // Persist updated settings + let row = client + .query_one( + "UPDATE models + SET settings = $1 + WHERE model_id = $2 + RETURNING id, model_id, created_at, updated_at", + &[&new_settings_json, ¶ms.model_id], + ) + .await?; + + Ok(Model { + id: row.get("id"), + model_id: row.get("model_id"), + settings: new_settings, + created_at: row.get("created_at"), + updated_at: row.get("updated_at"), + }) + } + + async fn delete_model(&self, model_id: &str) -> anyhow::Result { + tracing::info!("Repository: Deleting model for model_id={}", model_id); + + let client = self.pool.get().await?; + + let rows_affected = client + .execute( + "DELETE FROM models + WHERE model_id = $1", + &[&model_id], + ) + .await?; + + Ok(rows_affected > 0) + } +} + +fn load_settings_from_row(row: &Row) -> anyhow::Result { + let settings_json: serde_json::Value = row.get("settings"); + let default_settings = ModelSettings::default(); + let partial_settings = serde_json::from_value::(settings_json)?; + let settings = default_settings.into_updated(partial_settings); + Ok(settings) +} diff --git a/crates/database/src/repositories/system_configs_repository.rs b/crates/database/src/repositories/system_configs_repository.rs new file mode 100644 index 00000000..cdb3b643 --- /dev/null +++ b/crates/database/src/repositories/system_configs_repository.rs @@ -0,0 +1,79 @@ +use crate::pool::DbPool; +use async_trait::async_trait; +use services::system_configs::ports::{ + PartialSystemConfigs, SystemConfigs, SystemConfigsRepository, SystemKey, +}; + +pub struct PostgresSystemConfigsRepository { + pool: DbPool, +} + +impl PostgresSystemConfigsRepository { + pub fn new(pool: DbPool) -> Self { + Self { pool } + } +} + +#[async_trait] +impl SystemConfigsRepository for PostgresSystemConfigsRepository { + async fn get_configs(&self) -> anyhow::Result> { + tracing::debug!("Repository: Fetching system configs"); + + let client = self.pool.get().await?; + let key = SystemKey::Config.to_string(); + + let row = client + .query_opt("SELECT value FROM system_configs WHERE key = $1", &[&key]) + .await?; + + if let Some(row) = row { + let value_json: serde_json::Value = row.get("value"); + + let default_config = SystemConfigs::default(); + // Missing fields will be filled from default config values + let partial = serde_json::from_value::(value_json)?; + let config = default_config.into_updated(partial); + + Ok(Some(config)) + } else { + Ok(None) + } + } + + async fn upsert_configs(&self, configs: SystemConfigs) -> anyhow::Result { + tracing::info!("Repository: Upserting system configs"); + + let client = self.pool.get().await?; + let key = SystemKey::Config.to_string(); + let value_json = serde_json::to_value(&configs)?; + + client + .execute( + "INSERT INTO system_configs (key, value) + VALUES ($1, $2) + ON CONFLICT (key) + DO UPDATE SET value = $2, updated_at = NOW()", + &[&key, &value_json], + ) + .await?; + + tracing::info!("Repository: System configs upserted successfully"); + + Ok(configs) + } + + async fn update_configs(&self, configs: PartialSystemConfigs) -> anyhow::Result { + tracing::info!("Repository: Updating system configs"); + + // Load existing config, then merge with incoming partial + let existing = self.get_configs().await?; + let Some(existing) = existing else { + anyhow::bail!("System configs not found for key: {}", SystemKey::Config); + }; + + let merged = existing.into_updated(configs); + + // Reuse upsert logic to persist merged result + self.upsert_configs(merged).await + } +} diff --git a/crates/services/src/consts.rs b/crates/services/src/consts.rs new file mode 100644 index 00000000..801755db --- /dev/null +++ b/crates/services/src/consts.rs @@ -0,0 +1,2 @@ +/// Default value for model public visibility +pub const MODEL_PUBLIC_DEFAULT: bool = true; diff --git a/crates/services/src/lib.rs b/crates/services/src/lib.rs index 1ab1abc0..4e29e52e 100644 --- a/crates/services/src/lib.rs +++ b/crates/services/src/lib.rs @@ -1,9 +1,12 @@ pub mod analytics; pub mod auth; +pub mod consts; pub mod conversation; pub mod file; pub mod metrics; +pub mod model; pub mod response; +pub mod system_configs; pub mod types; pub mod user; pub mod vpc; diff --git a/crates/services/src/model/mod.rs b/crates/services/src/model/mod.rs new file mode 100644 index 00000000..83408c78 --- /dev/null +++ b/crates/services/src/model/mod.rs @@ -0,0 +1,2 @@ +pub mod ports; +pub mod service; diff --git a/crates/services/src/model/ports.rs b/crates/services/src/model/ports.rs new file mode 100644 index 00000000..ec6173ae --- /dev/null +++ b/crates/services/src/model/ports.rs @@ -0,0 +1,124 @@ +use async_trait::async_trait; +use chrono::{DateTime, Utc}; + +use crate::consts::MODEL_PUBLIC_DEFAULT; + +/// Model settings content structure +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ModelSettings { + /// Whether models are public (visible/usable in responses) + pub public: bool, + /// Optional system-level system prompt for this model + #[serde(skip_serializing_if = "Option::is_none")] + pub system_prompt: Option, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct PartialModelSettings { + pub public: Option, + pub system_prompt: Option, +} + +impl Default for ModelSettings { + /// Default model settings. + /// + /// By default, models are **public** (public = true). + fn default() -> Self { + Self { + public: MODEL_PUBLIC_DEFAULT, + system_prompt: None, + } + } +} + +impl ModelSettings { + pub fn into_updated(self, settings: PartialModelSettings) -> Self { + Self { + public: settings.public.unwrap_or(self.public), + system_prompt: settings.system_prompt.or(self.system_prompt), + } + } +} + +/// Model settings stored as JSONB in the database +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct Model { + pub id: uuid::Uuid, + pub model_id: String, + pub settings: ModelSettings, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct UpsertModelParams { + pub model_id: String, + pub settings: ModelSettings, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct UpdateModelParams { + pub model_id: String, + pub settings: Option, +} + +/// Repository trait for model settings operations +#[async_trait] +pub trait ModelsRepository: Send + Sync { + /// Get settings for a specific model. + /// Returns `Ok(None)` if no settings exist yet for that model. + async fn get_model(&self, model_id: &str) -> anyhow::Result>; + + /// List all models with pagination. + /// Returns a tuple of (models, total_count). + async fn list_models(&self, limit: i64, offset: i64) -> anyhow::Result<(Vec, i64)>; + + /// Batch get full model records for multiple model IDs. + /// Returns a map from model_id to resolved `ModelSettings`. + async fn get_models_by_ids( + &self, + model_ids: &[&str], + ) -> anyhow::Result>; + + /// Create or update a specific model with settings. + async fn upsert_model(&self, params: UpsertModelParams) -> anyhow::Result; + + /// Partially update an existing model (model_id + optional partial settings). + async fn update_model(&self, params: UpdateModelParams) -> anyhow::Result; + + /// Delete a specific model by its identifier. + /// + /// Returns `Ok(true)` if a model was deleted, or `Ok(false)` if no model + /// with the given `model_id` existed. + async fn delete_model(&self, model_id: &str) -> anyhow::Result; +} + +/// Service trait for model settings operations +#[async_trait] +pub trait ModelService: Send + Sync { + /// Get model + async fn get_model(&self, model_id: &str) -> anyhow::Result>; + + /// List all models with pagination. + /// Returns a tuple of (models, total_count). + async fn list_models(&self, limit: i64, offset: i64) -> anyhow::Result<(Vec, i64)>; + + /// Batch get settings content for multiple models. + /// Missing models will not appear in the map; callers should fall back to defaults. + async fn get_models_by_ids( + &self, + model_ids: &[&str], + ) -> anyhow::Result>; + + /// Fully update model + async fn upsert_model(&self, params: UpsertModelParams) -> anyhow::Result; + + /// Partially update model settings for a specific model. + async fn update_model(&self, params: UpdateModelParams) -> anyhow::Result; + + /// Delete a specific model by its identifier. + /// + /// Returns `Ok(true)` if a model was deleted, or `Ok(false)` if no model + /// with the given `model_id` existed. + async fn delete_model(&self, model_id: &str) -> anyhow::Result; +} diff --git a/crates/services/src/model/service.rs b/crates/services/src/model/service.rs new file mode 100644 index 00000000..50c4a1e5 --- /dev/null +++ b/crates/services/src/model/service.rs @@ -0,0 +1,54 @@ +use async_trait::async_trait; +use std::sync::Arc; + +use super::ports::{Model, ModelService, ModelsRepository, UpdateModelParams, UpsertModelParams}; + +pub struct ModelServiceImpl { + repository: Arc, +} + +impl ModelServiceImpl { + pub fn new(repository: Arc) -> Self { + Self { repository } + } +} + +#[async_trait] +impl ModelService for ModelServiceImpl { + async fn get_model(&self, model_id: &str) -> anyhow::Result> { + tracing::info!("Getting model settings for model_id={}", model_id); + + self.repository.get_model(model_id).await + } + + async fn list_models(&self, limit: i64, offset: i64) -> anyhow::Result<(Vec, i64)> { + tracing::info!("Listing models with limit={}, offset={}", limit, offset); + + self.repository.list_models(limit, offset).await + } + + async fn get_models_by_ids( + &self, + model_ids: &[&str], + ) -> anyhow::Result> { + self.repository.get_models_by_ids(model_ids).await + } + + async fn upsert_model(&self, params: UpsertModelParams) -> anyhow::Result { + tracing::info!("Upserting model for model_id={}", params.model_id); + + self.repository.upsert_model(params).await + } + + async fn update_model(&self, params: UpdateModelParams) -> anyhow::Result { + tracing::info!("Updating model for model_id={}", params.model_id); + + self.repository.update_model(params).await + } + + async fn delete_model(&self, model_id: &str) -> anyhow::Result { + tracing::info!("Deleting model for model_id={}", model_id); + + self.repository.delete_model(model_id).await + } +} diff --git a/crates/services/src/system_configs/mod.rs b/crates/services/src/system_configs/mod.rs new file mode 100644 index 00000000..83408c78 --- /dev/null +++ b/crates/services/src/system_configs/mod.rs @@ -0,0 +1,2 @@ +pub mod ports; +pub mod service; diff --git a/crates/services/src/system_configs/ports.rs b/crates/services/src/system_configs/ports.rs new file mode 100644 index 00000000..358a4a2a --- /dev/null +++ b/crates/services/src/system_configs/ports.rs @@ -0,0 +1,74 @@ +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::fmt; + +/// Key for `system_configs` table entries +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum SystemKey { + /// Application-wide configuration + Config, +} + +impl fmt::Display for SystemKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SystemKey::Config => write!(f, "config"), + } + } +} + +/// Application-wide configuration stored in `system_configs` table +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SystemConfigs { + /// Default model identifier to use when not specified + #[serde(skip_serializing_if = "Option::is_none")] + pub default_model: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PartialSystemConfigs { + pub default_model: Option, +} + +#[allow(clippy::derivable_impls)] +impl Default for SystemConfigs { + fn default() -> Self { + Self { + default_model: None, + } + } +} + +impl SystemConfigs { + pub fn into_updated(self, partial: PartialSystemConfigs) -> Self { + Self { + default_model: partial.default_model.or(self.default_model), + } + } +} + +/// Repository trait for accessing system configs +#[async_trait] +pub trait SystemConfigsRepository: Send + Sync { + /// Get system configs (if exists) + async fn get_configs(&self) -> anyhow::Result>; + + /// Create or update system configs (full replace) + async fn upsert_configs(&self, configs: SystemConfigs) -> anyhow::Result; + + /// Partially update system configs + async fn update_configs(&self, configs: PartialSystemConfigs) -> anyhow::Result; +} + +/// Service trait for system configs +#[async_trait] +pub trait SystemConfigsService: Send + Sync { + /// Get system configs (if exists) + async fn get_configs(&self) -> anyhow::Result>; + + /// Fully create or replace system configs (upsert) + async fn upsert_configs(&self, configs: SystemConfigs) -> anyhow::Result; + + /// Partially update system configs + async fn update_configs(&self, configs: PartialSystemConfigs) -> anyhow::Result; +} diff --git a/crates/services/src/system_configs/service.rs b/crates/services/src/system_configs/service.rs new file mode 100644 index 00000000..690e724b --- /dev/null +++ b/crates/services/src/system_configs/service.rs @@ -0,0 +1,37 @@ +use async_trait::async_trait; +use std::sync::Arc; + +use super::ports::{ + PartialSystemConfigs, SystemConfigs, SystemConfigsRepository, SystemConfigsService, +}; + +pub struct SystemConfigsServiceImpl { + repository: Arc, +} + +impl SystemConfigsServiceImpl { + pub fn new(repository: Arc) -> Self { + Self { repository } + } +} + +#[async_trait] +impl SystemConfigsService for SystemConfigsServiceImpl { + async fn get_configs(&self) -> anyhow::Result> { + tracing::info!("Getting system configs"); + + self.repository.get_configs().await + } + + async fn upsert_configs(&self, configs: SystemConfigs) -> anyhow::Result { + tracing::info!("Upserting system configs"); + + self.repository.upsert_configs(configs).await + } + + async fn update_configs(&self, configs: PartialSystemConfigs) -> anyhow::Result { + tracing::info!("Partially updating system configs"); + + self.repository.update_configs(configs).await + } +} diff --git a/crates/services/src/vpc/service.rs b/crates/services/src/vpc/service.rs index 4bcaa377..ad171627 100644 --- a/crates/services/src/vpc/service.rs +++ b/crates/services/src/vpc/service.rs @@ -29,13 +29,6 @@ struct VpcOrganization { id: String, } -/// Response from access token refresh endpoint -#[derive(serde::Deserialize)] -struct AccessTokenResponse { - access_token: String, - refresh_token: String, -} - /// Cached credentials with tokens struct CachedCredentials { access_token: String, @@ -122,42 +115,6 @@ impl VpcCredentialsServiceImpl { Ok(login_response) } - /// Refresh access token using refresh token - async fn refresh_access_token( - &self, - config: &VpcAuthConfig, - refresh_token: &str, - ) -> anyhow::Result { - let url = format!( - "{}/users/me/access-tokens", - config.base_url.trim_end_matches('/') - ); - - tracing::debug!("Refreshing access token..."); - - let response = self - .http_client - .post(&url) - .header("Authorization", format!("Bearer {}", refresh_token)) - .send() - .await?; - - if response.status() == reqwest::StatusCode::UNAUTHORIZED { - anyhow::bail!("Refresh token expired"); - } - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - anyhow::bail!("Token refresh failed with status {}: {}", status, body); - } - - let token_response: AccessTokenResponse = response.json().await?; - tracing::info!("Access token refreshed successfully"); - - Ok(token_response) - } - /// Load credentials from database async fn load_from_db(&self) -> anyhow::Result> { let refresh_token = self.repository.get(VPC_REFRESH_TOKEN_CONFIG_KEY).await?; @@ -254,37 +211,6 @@ impl VpcCredentialsServiceImpl { } } - // If we have a refresh token, try to refresh - if let Some(creds) = cached.as_mut() { - match self - .refresh_access_token(config, &creds.refresh_token) - .await - { - Ok(token_response) => { - creds.access_token = token_response.access_token.clone(); - creds.access_token_created_at = std::time::Instant::now(); - creds.refresh_token = token_response.refresh_token.clone(); - - // Update refresh token in database (it rotates) - self.save_to_db(creds).await; - - return Ok(VpcCredentials { - access_token: token_response.access_token, - organization_id: creds.organization_id.clone(), - api_key: creds.api_key.clone(), - }); - } - Err(e) => { - tracing::warn!( - "Failed to refresh access token, will re-authenticate: {}", - e - ); - // Clear cached to force re-auth - *cached = None; - } - } - } - // No cached credentials or refresh failed - perform full VPC auth tracing::info!("Performing full VPC authentication..."); let login_response = self.vpc_authenticate(config).await?;