Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions crates/api/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,7 @@ impl ApiError {
/// OAuth provider error
pub fn oauth_provider_error(provider: &str) -> Self {
Self::bad_gateway(format!(
"Failed to communicate with {} OAuth provider",
provider
"Failed to communicate with {provider} OAuth provider"
))
}

Expand Down
2 changes: 1 addition & 1 deletion crates/api/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use utoipa_swagger_ui::SwaggerUi;
async fn main() -> anyhow::Result<()> {
// Load .env file if it exists
if let Err(e) = dotenvy::dotenv() {
eprintln!("Warning: Could not load .env file: {}", e);
eprintln!("Warning: Could not load .env file: {e}");
eprintln!("Continuing with environment variables...");
}

Expand Down
67 changes: 46 additions & 21 deletions crates/api/src/middleware/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,49 +39,64 @@ pub async fn auth_middleware(
) -> Result<Response, Response> {
let path = request.uri().path();
let method = request.method().clone();

tracing::info!(
"Auth middleware invoked for {} {}",
method,
path
);

tracing::info!("Auth middleware invoked for {} {}", method, path);

// Try to extract authentication from Authorization header
let auth_header = request
.headers()
.get("authorization")
.and_then(|h| h.to_str().ok());

tracing::debug!("Auth middleware: processing request with auth_header present: {}", auth_header.is_some());

tracing::debug!(
"Auth middleware: processing request with auth_header present: {}",
auth_header.is_some()
);

if let Some(auth_value) = auth_header {
tracing::debug!("Authorization header value prefix: {}...", &auth_value.chars().take(10).collect::<String>());
tracing::debug!(
"Authorization header value prefix: {}...",
&auth_value.chars().take(10).collect::<String>()
);
}

let auth_result = if let Some(auth_value) = auth_header {
if let Some(token) = auth_value.strip_prefix("Bearer ") {
tracing::debug!("Extracted Bearer token, length: {}, prefix: {}...", token.len(), &token.chars().take(8).collect::<String>());
tracing::debug!(
"Extracted Bearer token, length: {}, prefix: {}...",
token.len(),
&token.chars().take(8).collect::<String>()
);

// Validate token format (should start with sess_ and be the right length)
if !token.starts_with("sess_") {
tracing::warn!("Invalid session token format: token does not start with 'sess_'");
return Err(ApiError::invalid_token().into_response());
}

if token.len() != 37 {
tracing::warn!("Invalid session token format: expected length 37, got {}", token.len());
tracing::warn!(
"Invalid session token format: expected length 37, got {}",
token.len()
);
return Err(ApiError::invalid_token().into_response());
}

tracing::debug!("Token format validation passed, proceeding to authenticate");

// Hash the token and look it up
let token_hash = hash_session_token(token);
tracing::debug!("Token hashed, hash prefix: {}...", &token_hash.chars().take(16).collect::<String>());

tracing::debug!(
"Token hashed, hash prefix: {}...",
&token_hash.chars().take(16).collect::<String>()
);

authenticate_session_by_token(&state, token_hash).await
} else {
tracing::warn!("Authorization header does not start with 'Bearer ', header: {}", auth_value);
tracing::warn!(
"Authorization header does not start with 'Bearer ', header: {}",
auth_value
);
Err(ApiError::invalid_auth_header())
}
} else {
Expand Down Expand Up @@ -116,19 +131,29 @@ async fn authenticate_session_by_token(
state: &AuthState,
token_hash: String,
) -> Result<AuthenticatedUser, ApiError> {
tracing::debug!("Authenticating session by token hash: {}...", &token_hash.chars().take(16).collect::<String>());

tracing::debug!(
"Authenticating session by token hash: {}...",
&token_hash.chars().take(16).collect::<String>()
);

// Look up the session by token hash
let session = state
.session_repository
.get_session_by_token_hash(token_hash.clone())
.await
.map_err(|e| {
tracing::error!("Failed to get session from repository for token_hash {}...: {}", &token_hash.chars().take(16).collect::<String>(), e);
tracing::error!(
"Failed to get session from repository for token_hash {}...: {}",
&token_hash.chars().take(16).collect::<String>(),
e
);
ApiError::internal_server_error("Failed to authenticate session")
})?
.ok_or_else(|| {
tracing::warn!("Session not found for token_hash: {}...", &token_hash.chars().take(16).collect::<String>());
tracing::warn!(
"Session not found for token_hash: {}...",
&token_hash.chars().take(16).collect::<String>()
);
ApiError::session_not_found()
})?;

Expand All @@ -151,7 +176,7 @@ async fn authenticate_session_by_token(
);
return Err(ApiError::session_expired());
}

let time_until_expiry = session.expires_at.signed_duration_since(now);
tracing::debug!(
"Session valid for {} more seconds",
Expand Down
24 changes: 11 additions & 13 deletions crates/api/src/routes/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async fn create_conversation(
(
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Failed to read request body: {}", e),
error: format!("Failed to read request body: {e}"),
}),
)
.into_response()
Expand Down Expand Up @@ -105,7 +105,7 @@ async fn create_conversation(
(
StatusCode::BAD_GATEWAY,
Json(ErrorResponse {
error: format!("OpenAI API error: {}", e),
error: format!("OpenAI API error: {e}"),
}),
)
.into_response()
Expand All @@ -131,7 +131,7 @@ async fn create_conversation(
(
StatusCode::BAD_GATEWAY,
Json(ErrorResponse {
error: format!("Failed to read response: {}", e),
error: format!("Failed to read response: {e}"),
}),
)
.into_response()
Expand All @@ -141,7 +141,7 @@ async fn create_conversation(
.collect();

// If successful, parse response and track conversation
if status >= 200 && status < 300 {
if (200..300).contains(&status) {
tracing::debug!("Parsing successful conversation creation response");

// Decompress if gzipped
Expand Down Expand Up @@ -227,7 +227,7 @@ async fn create_conversation(
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to build response: {}", e),
error: format!("Failed to build response: {e}"),
}),
)
.into_response()
Expand All @@ -254,7 +254,7 @@ async fn list_conversations(
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to list conversations: {}", e),
error: format!("Failed to list conversations: {e}"),
}),
)
.into_response()
Expand Down Expand Up @@ -316,7 +316,7 @@ async fn proxy_handler(
path
);

if body_bytes.len() > 0 {
if !body_bytes.is_empty() {
if let Ok(body_str) = std::str::from_utf8(&body_bytes) {
tracing::debug!("Request body content: {}", body_str);
}
Expand Down Expand Up @@ -358,7 +358,7 @@ async fn proxy_handler(
(
StatusCode::BAD_GATEWAY,
Json(ErrorResponse {
error: format!("OpenAI API error: {}", e),
error: format!("OpenAI API error: {e}"),
}),
)
.into_response()
Expand Down Expand Up @@ -387,16 +387,14 @@ async fn proxy_handler(
}

// Convert the stream to an axum Body for streaming support
let stream = proxy_response
.body
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
let stream = proxy_response.body.map_err(std::io::Error::other);
let body = Body::from_stream(stream);

response.body(body).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to build response: {}", e),
error: format!("Failed to build response: {e}"),
}),
)
.into_response()
Expand All @@ -413,7 +411,7 @@ async fn extract_body_bytes(request: Request) -> Result<Bytes, Response> {
(
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Failed to read request body: {}", e),
error: format!("Failed to read request body: {e}"),
}),
)
.into_response()
Expand Down
9 changes: 5 additions & 4 deletions crates/api/src/routes/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ pub async fn oauth_callback(
);

// The provider is determined from the state stored in the database
// Returns (session, frontend_callback_url)
let (session, frontend_callback) = app_state
// Returns (session, frontend_callback_url, is_new_user)
let (session, frontend_callback, is_new_user) = app_state
.oauth_service
.handle_callback_unified(params.code.clone(), params.state.clone())
.await
Expand Down Expand Up @@ -140,10 +140,11 @@ pub async fn oauth_callback(
tracing::info!("Redirecting to frontend: {}", frontend_url);

let callback_url = format!(
"{}/auth/callback?token={}&expires_at={}",
"{}/auth/callback?token={}&expires_at={}&is_new_user={}",
frontend_url,
urlencoding::encode(&token),
urlencoding::encode(&session.expires_at.to_rfc3339())
urlencoding::encode(&session.expires_at.to_rfc3339()),
is_new_user
);

tracing::debug!("Final callback URL: {}", callback_url);
Expand Down
24 changes: 12 additions & 12 deletions crates/api/tests/e2e_api_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ async fn test_conversation_workflow() {
.post("/v1/conversations")
.add_header(
http::HeaderName::from_static("authorization"),
http::HeaderValue::from_str(&format!("Bearer {}", SESSION_TOKEN)).unwrap(),
http::HeaderValue::from_str(&format!("Bearer {SESSION_TOKEN}")).unwrap(),
)
.json(&create_conv_body)
.await;

let status = response.status_code();
println!(" Status: {}", status);
println!(" Status: {status}");

let conversation_id = if status.is_success() {
let body: serde_json::Value = response.json();
Expand All @@ -104,11 +104,11 @@ async fn test_conversation_workflow() {
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.expect("Conversation should have an ID");
println!(" Conversation ID: {}", conv_id);
println!(" Conversation ID: {conv_id}");
conv_id
} else {
let error_text = response.text();
println!(" ✗ Failed: {}", error_text);
println!(" ✗ Failed: {error_text}");
panic!("Failed to create conversation");
};

Expand All @@ -130,13 +130,13 @@ async fn test_conversation_workflow() {
.post("/v1/responses")
.add_header(
http::HeaderName::from_static("authorization"),
http::HeaderValue::from_str(&format!("Bearer {}", SESSION_TOKEN)).unwrap(),
http::HeaderValue::from_str(&format!("Bearer {SESSION_TOKEN}")).unwrap(),
)
.json(&request_body)
.await;

let status = response.status_code();
println!(" Status: {}", status);
println!(" Status: {status}");

if status.is_success() {
let body: serde_json::Value = response.json();
Expand All @@ -147,7 +147,7 @@ async fn test_conversation_workflow() {
);
} else {
let error_text = response.text();
println!(" ✗ Failed: {}", error_text);
println!(" ✗ Failed: {error_text}");
panic!("Failed to create first response");
};

Expand All @@ -169,13 +169,13 @@ async fn test_conversation_workflow() {
.post("/v1/responses")
.add_header(
http::HeaderName::from_static("authorization"),
http::HeaderValue::from_str(&format!("Bearer {}", SESSION_TOKEN)).unwrap(),
http::HeaderValue::from_str(&format!("Bearer {SESSION_TOKEN}")).unwrap(),
)
.json(&request_body)
.await;

let status = response.status_code();
println!(" Status: {}", status);
println!(" Status: {status}");

if status.is_success() {
let body: serde_json::Value = response.json();
Expand All @@ -186,7 +186,7 @@ async fn test_conversation_workflow() {
);
} else {
let error_text = response.text();
println!(" ✗ Failed: {}", error_text);
println!(" ✗ Failed: {error_text}");
panic!("Failed to create second response");
};

Expand All @@ -196,14 +196,14 @@ async fn test_conversation_workflow() {
.get("/v1/conversations")
.add_header(
http::HeaderName::from_static("authorization"),
http::HeaderValue::from_str(&format!("Bearer {}", SESSION_TOKEN)).unwrap(),
http::HeaderValue::from_str(&format!("Bearer {SESSION_TOKEN}")).unwrap(),
)
.await;

assert_eq!(response.status_code(), 200, "Should list conversations");

let conversations: Vec<serde_json::Value> = response.json();
println!("{:?}", conversations);
println!("{conversations:?}");
println!(" Found {} total conversations", conversations.len());

// Find our conversation
Expand Down
2 changes: 1 addition & 1 deletion crates/services/src/auth/ports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ pub trait OAuthService: Send + Sync {
&self,
code: String,
state: String,
) -> anyhow::Result<(UserSession, Option<String>)>;
) -> anyhow::Result<(UserSession, Option<String>, bool)>;

/// Refresh an access token
async fn refresh_token(
Expand Down
Loading