Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/api/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ async fn main() -> anyhow::Result<()> {
as Arc<dyn services::subscription::ports::SubscriptionRepository>,
webhook_repo: db.payment_webhook_repository()
as Arc<dyn services::subscription::ports::PaymentWebhookRepository>,
purchased_token_repo: db.purchased_token_repository()
as Arc<dyn services::subscription::ports::PurchasedTokenRepository>,
system_configs_service: system_configs_service.clone()
as Arc<dyn services::system_configs::ports::SystemConfigsService>,
user_repository: user_repo.clone(),
Expand Down
1 change: 1 addition & 0 deletions crates/api/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@ impl TryFrom<UpsertSystemConfigsRequest> for services::system_configs::ports::Pa
rate_limit,
subscription_plans: req.subscription_plans,
max_instances_per_manager: req.max_instances_per_manager,
..Default::default()
})
}
}
Expand Down
7 changes: 7 additions & 0 deletions crates/api/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ use utoipa::OpenApi;
crate::routes::subscriptions::resume_subscription,
crate::routes::subscriptions::list_plans,
crate::routes::subscriptions::list_subscriptions,
crate::routes::subscriptions::create_token_purchase,
crate::routes::subscriptions::get_purchased_token_balance,
crate::routes::subscriptions::get_tokens_purchase_info,
// Admin endpoints
crate::routes::admin::list_users,
crate::routes::admin::list_models,
Expand Down Expand Up @@ -161,6 +164,10 @@ use utoipa::OpenApi;
crate::routes::subscriptions::ResumeSubscriptionResponse,
crate::routes::subscriptions::ListSubscriptionsResponse,
crate::routes::subscriptions::ListPlansResponse,
crate::routes::subscriptions::CreateTokenPurchaseRequest,
crate::routes::subscriptions::CreateTokenPurchaseResponse,
crate::routes::subscriptions::PurchasedTokenBalanceResponse,
crate::routes::subscriptions::TokensPurchaseInfoResponse,
services::subscription::ports::SubscriptionWithPlan,
services::subscription::ports::SubscriptionPlan,
// Attestation models
Expand Down
32 changes: 30 additions & 2 deletions crates/api/src/routes/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2487,7 +2487,8 @@ async fn proxy_responses(
state.user_usage_service.clone(),
state.model_pricing_cache.clone(),
user.user_id,
);
)
.with_subscription_service(state.subscription_service.clone());
if let Some(Extension(api_key)) = api_key_ext {
usage_stream = usage_stream.with_agent_service(state.agent_service.clone(), api_key);
}
Expand Down Expand Up @@ -3540,7 +3541,8 @@ async fn proxy_chat_completions(
state.user_usage_service.clone(),
state.model_pricing_cache.clone(),
user.user_id,
);
)
.with_subscription_service(state.subscription_service.clone());
if let Some(Extension(api_key)) = api_key_ext {
usage_stream = usage_stream.with_agent_service(state.agent_service.clone(), api_key);
}
Expand Down Expand Up @@ -4710,6 +4712,19 @@ async fn record_chat_usage_from_body(
return false;
}

// Debit purchased tokens if usage overflowed monthly limit
if let Err(e) = state
.subscription_service
.debit_purchased_tokens_after_usage(user_id, usage.total_tokens as i64)
.await
{
tracing::warn!(
"Failed to debit purchased tokens for overflow (user_id={}): {}",
user_id,
e
);
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for debiting purchased tokens is duplicated in record_response_usage_from_body (lines 4798-4809). To improve maintainability and reduce code duplication, consider extracting this block into a separate helper function.

For example, you could create a helper function like this:

async fn debit_overflow_tokens(state: &AppState, user_id: UserId, tokens: i64) {
    if let Err(e) = state
        .subscription_service
        .debit_purchased_tokens_after_usage(user_id, tokens)
        .await
    {
        tracing::warn!(
            "Failed to debit purchased tokens for overflow (user_id={}): {}",
            user_id,
            e
        );
    }
}

Then you could call it from both record_chat_usage_from_body and record_response_usage_from_body:

debit_overflow_tokens(&state, user_id, usage.total_tokens as i64).await;


// Record to agent_usage_log if request was made with an API key
if let Some(api_key) = api_key_ext {
let pricing_opt = state.model_pricing_cache.get_pricing(&request_model).await;
Expand Down Expand Up @@ -4780,6 +4795,19 @@ async fn record_response_usage_from_body(
return false;
}

// Debit purchased tokens if usage overflowed monthly limit
if let Err(e) = state
.subscription_service
.debit_purchased_tokens_after_usage(user_id, usage.total_tokens as i64)
.await
{
tracing::warn!(
"Failed to debit purchased tokens for overflow (user_id={}): {}",
user_id,
e
);
}

// Record to agent_usage_log if request was made with an API key
if let Some(api_key) = api_key_ext {
let pricing_opt = state.model_pricing_cache.get_pricing(&usage.model).await;
Expand Down
165 changes: 165 additions & 0 deletions crates/api/src/routes/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,38 @@ pub struct ListPlansResponse {
pub plans: Vec<SubscriptionPlan>,
}

/// Request to create a token purchase checkout
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct CreateTokenPurchaseRequest {
/// URL to redirect after successful checkout
pub success_url: String,
/// URL to redirect after cancelled checkout
pub cancel_url: String,
}

/// Response containing token purchase checkout URL
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct CreateTokenPurchaseResponse {
/// Stripe checkout URL for completing token purchase
pub checkout_url: String,
}

/// Response containing purchased token balance
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct PurchasedTokenBalanceResponse {
/// Current purchased token balance (spendable)
pub balance: i64,
}

/// Response containing token purchase info (for UI display)
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct TokensPurchaseInfoResponse {
/// Tokens per purchase (e.g. 1_000_000)
pub amount: u64,
/// Price per 1M tokens in USD (e.g. 1.70)
pub price_per_million: f64,
}

/// Request to create a customer portal session
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct CreatePortalSessionRequest {
Expand Down Expand Up @@ -212,6 +244,10 @@ pub async fn create_subscription(
tracing::error!("Unexpected MonthlyTokenLimitExceeded in create");
ApiError::internal_server_error("Failed to create subscription")
}
SubscriptionError::TokenPurchaseNotConfigured => {
tracing::error!("Unexpected TokenPurchaseNotConfigured in create");
ApiError::internal_server_error("Failed to create subscription")
}
})?;

Ok(Json(CreateSubscriptionResponse { checkout_url }))
Expand Down Expand Up @@ -483,6 +519,123 @@ pub async fn handle_stripe_webhook(
Ok(Json(serde_json::json!({ "received": true })))
}

/// Create token purchase checkout session
#[utoipa::path(
post,
path = "/v1/subscriptions/tokens/purchase",
tag = "Subscriptions",
request_body = CreateTokenPurchaseRequest,
responses(
(status = 200, description = "Checkout session created successfully", body = CreateTokenPurchaseResponse),
(status = 400, description = "Invalid request", body = crate::error::ApiErrorResponse),
(status = 401, description = "Unauthorized", body = crate::error::ApiErrorResponse),
(status = 503, description = "Token purchase not configured", body = crate::error::ApiErrorResponse)
),
security(
("session_token" = [])
)
)]
pub async fn create_token_purchase(
State(app_state): State<AppState>,
Extension(user): Extension<AuthenticatedUser>,
Json(req): Json<CreateTokenPurchaseRequest>,
) -> Result<Json<CreateTokenPurchaseResponse>, ApiError> {
tracing::info!(
"Creating token purchase checkout for user_id={}",
user.user_id
);

validate_redirect_url(&req.success_url, "success_url")?;
validate_redirect_url(&req.cancel_url, "cancel_url")?;

let checkout_url = app_state
.subscription_service
.create_token_purchase_checkout(user.user_id, req.success_url, req.cancel_url)
.await
.map_err(|e| match e {
SubscriptionError::TokenPurchaseNotConfigured => {
ApiError::service_unavailable("Token purchase is not configured")
}
SubscriptionError::NoStripeCustomer => {
ApiError::not_found("No Stripe customer found for this user")
}
SubscriptionError::StripeError(msg) => {
tracing::error!(error = ?msg, "Stripe error creating token purchase checkout");
ApiError::internal_server_error("Failed to create checkout")
}
_ => {
tracing::error!(error = ?e, "Failed to create token purchase checkout");
ApiError::internal_server_error("Failed to create checkout")
}
})?;

Ok(Json(CreateTokenPurchaseResponse { checkout_url }))
}

/// Get purchased token balance
#[utoipa::path(
get,
path = "/v1/subscriptions/tokens/balance",
tag = "Subscriptions",
responses(
(status = 200, description = "Balance retrieved successfully", body = PurchasedTokenBalanceResponse),
(status = 401, description = "Unauthorized", body = crate::error::ApiErrorResponse)
),
security(
("session_token" = [])
)
)]
pub async fn get_purchased_token_balance(
State(app_state): State<AppState>,
Extension(user): Extension<AuthenticatedUser>,
) -> Result<Json<PurchasedTokenBalanceResponse>, ApiError> {
let balance = app_state
.subscription_service
.get_purchased_token_balance(user.user_id)
.await
.map_err(|e| {
tracing::error!(error = ?e, "Failed to get purchased token balance");
ApiError::internal_server_error("Failed to get balance")
})?;

Ok(Json(PurchasedTokenBalanceResponse { balance }))
}

/// Get token purchase info (amount, price) for UI display
#[utoipa::path(
get,
path = "/v1/subscriptions/tokens/purchase-info",
tag = "Subscriptions",
responses(
(status = 200, description = "Purchase info retrieved successfully", body = TokensPurchaseInfoResponse),
(status = 401, description = "Unauthorized", body = crate::error::ApiErrorResponse),
(status = 404, description = "Token purchase not configured", body = crate::error::ApiErrorResponse)
),
security(
("session_token" = [])
)
)]
pub async fn get_tokens_purchase_info(
State(app_state): State<AppState>,
Extension(_user): Extension<AuthenticatedUser>,
) -> Result<Json<TokensPurchaseInfoResponse>, ApiError> {
let info = app_state
.subscription_service
.get_tokens_purchase_info()
.await
.map_err(|e| {
tracing::error!(error = ?e, "Failed to get tokens purchase info");
ApiError::internal_server_error("Failed to get purchase info")
})?;

let info = info.ok_or_else(|| ApiError::not_found("Token purchase is not configured"))?;

Ok(Json(TokensPurchaseInfoResponse {
amount: info.amount,
price_per_million: info.price_per_million,
}))
}

/// Create subscription router with authenticated routes
pub fn create_subscriptions_router() -> Router<AppState> {
Router::new()
Expand All @@ -491,6 +644,18 @@ pub fn create_subscriptions_router() -> Router<AppState> {
.route("/v1/subscriptions/cancel", post(cancel_subscription))
.route("/v1/subscriptions/resume", post(resume_subscription))
.route("/v1/subscriptions/portal", post(create_portal_session))
.route(
"/v1/subscriptions/tokens/purchase",
post(create_token_purchase),
)
.route(
"/v1/subscriptions/tokens/balance",
get(get_purchased_token_balance),
)
.route(
"/v1/subscriptions/tokens/purchase-info",
get(get_tokens_purchase_info),
)
}

/// Create public subscription router (for webhooks and plans - no auth)
Expand Down
35 changes: 35 additions & 0 deletions crates/api/src/usage_parsing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ pub struct UsageTrackingStreamChatCompletions<S> {
buffer: String,
usage: Option<ParsedUsage>,
user_usage: Arc<dyn services::user_usage::UserUsageService>,
subscription_service: Option<Arc<dyn services::subscription::ports::SubscriptionService>>,
pricing_cache: crate::model_pricing::ModelPricingCache,
user_id: UserId,
// Optional agent service + API key for dual tracking
Expand All @@ -158,13 +159,22 @@ where
buffer: String::new(),
usage: None,
user_usage,
subscription_service: None,
pricing_cache,
user_id,
agent_service: None,
api_key_ext: None,
}
}

pub fn with_subscription_service(
mut self,
subscription_service: Arc<dyn services::subscription::ports::SubscriptionService>,
) -> Self {
self.subscription_service = Some(subscription_service);
self
}

pub fn with_agent_service(
mut self,
agent_service: Arc<dyn services::agent::AgentService>,
Expand Down Expand Up @@ -207,6 +217,7 @@ where
this.usage.take(),
StreamUsageContext {
user_usage: this.user_usage.clone(),
subscription_service: this.subscription_service.clone(),
pricing_cache: this.pricing_cache.clone(),
user_id: this.user_id,
stream_name: "UsageTrackingStreamChatCompletions",
Expand All @@ -232,6 +243,7 @@ pub struct UsageTrackingStreamResponseCompleted<S> {
buffer: String,
usage: Option<ParsedUsage>,
user_usage: Arc<dyn services::user_usage::UserUsageService>,
subscription_service: Option<Arc<dyn services::subscription::ports::SubscriptionService>>,
pricing_cache: crate::model_pricing::ModelPricingCache,
user_id: UserId,
// Optional agent service + API key for dual tracking
Expand All @@ -254,13 +266,22 @@ where
buffer: String::new(),
usage: None,
user_usage,
subscription_service: None,
pricing_cache,
user_id,
agent_service: None,
api_key_ext: None,
}
}

pub fn with_subscription_service(
mut self,
subscription_service: Arc<dyn services::subscription::ports::SubscriptionService>,
) -> Self {
self.subscription_service = Some(subscription_service);
self
}

pub fn with_agent_service(
mut self,
agent_service: Arc<dyn services::agent::AgentService>,
Expand Down Expand Up @@ -300,6 +321,7 @@ where
this.usage.take(),
StreamUsageContext {
user_usage: this.user_usage.clone(),
subscription_service: this.subscription_service.clone(),
pricing_cache: this.pricing_cache.clone(),
user_id: this.user_id,
stream_name: "UsageTrackingStreamResponseCompleted",
Expand All @@ -317,6 +339,7 @@ where

struct StreamUsageContext {
user_usage: Arc<dyn services::user_usage::UserUsageService>,
subscription_service: Option<Arc<dyn services::subscription::ports::SubscriptionService>>,
pricing_cache: crate::model_pricing::ModelPricingCache,
user_id: UserId,
stream_name: &'static str,
Expand Down Expand Up @@ -359,6 +382,18 @@ fn record_usage_on_stream_end(usage: Option<ParsedUsage>, ctx: StreamUsageContex
ctx.user_id,
e
);
} else if let Some(ref sub) = ctx.subscription_service {
// Debit purchased tokens if usage overflowed monthly limit
if let Err(e) = sub
.debit_purchased_tokens_after_usage(ctx.user_id, usage.total_tokens as i64)
.await
{
tracing::warn!(
"Failed to debit purchased tokens for overflow (user_id={}): {}",
ctx.user_id,
e
);
}
}

// ADDITIONALLY: Record to agent_usage_log if API key is present
Expand Down
Loading
Loading