diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..5cc36bc3 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,75 @@ +name: Tests + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + +jobs: + test: + name: Test Suite + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:15 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: chat_api + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + components: rustfmt, clippy + + - name: Cache dependencies + uses: actions/cache@v3 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Run cargo fmt + run: cargo fmt --all -- --check + + - name: Run cargo clippy + run: cargo clippy --all-targets --all-features -- -D warnings + + - name: Run unit tests + run: cargo test --lib --bins + + - name: Run integration tests + run: cargo test --test e2e_api_tests + env: + DATABASE_HOST: localhost + DATABASE_PORT: 5432 + DATABASE_NAME: chat_api + DATABASE_USERNAME: postgres + DATABASE_PASSWORD: postgres + DATABASE_MAX_CONNECTIONS: 5 + DATABASE_TLS_ENABLED: "false" + RUST_LOG: debug + DEV: "true" + + - name: Build release + run: cargo build --release diff --git a/crates/api/src/error.rs b/crates/api/src/error.rs index 2f0c12a6..e6920b68 100644 --- a/crates/api/src/error.rs +++ b/crates/api/src/error.rs @@ -145,8 +145,7 @@ impl ApiError { /// OAuth provider error pub fn oauth_provider_error(provider: &str) -> Self { Self::bad_gateway(format!( - "Failed to communicate with {} OAuth provider", - provider + "Failed to communicate with {provider} OAuth provider" )) } diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index abb2727a..ce80d1cd 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -12,7 +12,7 @@ use utoipa_swagger_ui::SwaggerUi; async fn main() -> anyhow::Result<()> { // Load .env file if it exists if let Err(e) = dotenvy::dotenv() { - eprintln!("Warning: Could not load .env file: {}", e); + eprintln!("Warning: Could not load .env file: {e}"); eprintln!("Continuing with environment variables..."); } diff --git a/crates/api/src/middleware/auth.rs b/crates/api/src/middleware/auth.rs index a5cd0c65..c72688b4 100644 --- a/crates/api/src/middleware/auth.rs +++ b/crates/api/src/middleware/auth.rs @@ -39,12 +39,8 @@ pub async fn auth_middleware( ) -> Result { let path = request.uri().path(); let method = request.method().clone(); - - tracing::info!( - "Auth middleware invoked for {} {}", - method, - path - ); + + tracing::info!("Auth middleware invoked for {} {}", method, path); // Try to extract authentication from Authorization header let auth_header = request @@ -52,36 +48,55 @@ pub async fn auth_middleware( .get("authorization") .and_then(|h| h.to_str().ok()); - tracing::debug!("Auth middleware: processing request with auth_header present: {}", auth_header.is_some()); - + tracing::debug!( + "Auth middleware: processing request with auth_header present: {}", + auth_header.is_some() + ); + if let Some(auth_value) = auth_header { - tracing::debug!("Authorization header value prefix: {}...", &auth_value.chars().take(10).collect::()); + tracing::debug!( + "Authorization header value prefix: {}...", + &auth_value.chars().take(10).collect::() + ); } let auth_result = if let Some(auth_value) = auth_header { if let Some(token) = auth_value.strip_prefix("Bearer ") { - tracing::debug!("Extracted Bearer token, length: {}, prefix: {}...", token.len(), &token.chars().take(8).collect::()); + tracing::debug!( + "Extracted Bearer token, length: {}, prefix: {}...", + token.len(), + &token.chars().take(8).collect::() + ); // Validate token format (should start with sess_ and be the right length) if !token.starts_with("sess_") { tracing::warn!("Invalid session token format: token does not start with 'sess_'"); return Err(ApiError::invalid_token().into_response()); } - + if token.len() != 37 { - tracing::warn!("Invalid session token format: expected length 37, got {}", token.len()); + tracing::warn!( + "Invalid session token format: expected length 37, got {}", + token.len() + ); return Err(ApiError::invalid_token().into_response()); } tracing::debug!("Token format validation passed, proceeding to authenticate"); - + // Hash the token and look it up let token_hash = hash_session_token(token); - tracing::debug!("Token hashed, hash prefix: {}...", &token_hash.chars().take(16).collect::()); - + tracing::debug!( + "Token hashed, hash prefix: {}...", + &token_hash.chars().take(16).collect::() + ); + authenticate_session_by_token(&state, token_hash).await } else { - tracing::warn!("Authorization header does not start with 'Bearer ', header: {}", auth_value); + tracing::warn!( + "Authorization header does not start with 'Bearer ', header: {}", + auth_value + ); Err(ApiError::invalid_auth_header()) } } else { @@ -116,19 +131,29 @@ async fn authenticate_session_by_token( state: &AuthState, token_hash: String, ) -> Result { - tracing::debug!("Authenticating session by token hash: {}...", &token_hash.chars().take(16).collect::()); - + tracing::debug!( + "Authenticating session by token hash: {}...", + &token_hash.chars().take(16).collect::() + ); + // Look up the session by token hash let session = state .session_repository .get_session_by_token_hash(token_hash.clone()) .await .map_err(|e| { - tracing::error!("Failed to get session from repository for token_hash {}...: {}", &token_hash.chars().take(16).collect::(), e); + tracing::error!( + "Failed to get session from repository for token_hash {}...: {}", + &token_hash.chars().take(16).collect::(), + e + ); ApiError::internal_server_error("Failed to authenticate session") })? .ok_or_else(|| { - tracing::warn!("Session not found for token_hash: {}...", &token_hash.chars().take(16).collect::()); + tracing::warn!( + "Session not found for token_hash: {}...", + &token_hash.chars().take(16).collect::() + ); ApiError::session_not_found() })?; @@ -151,7 +176,7 @@ async fn authenticate_session_by_token( ); return Err(ApiError::session_expired()); } - + let time_until_expiry = session.expires_at.signed_duration_since(now); tracing::debug!( "Session valid for {} more seconds", diff --git a/crates/api/src/routes/api.rs b/crates/api/src/routes/api.rs index 6cbdfe92..b88cbc5d 100644 --- a/crates/api/src/routes/api.rs +++ b/crates/api/src/routes/api.rs @@ -65,7 +65,7 @@ async fn create_conversation( ( StatusCode::BAD_REQUEST, Json(ErrorResponse { - error: format!("Failed to read request body: {}", e), + error: format!("Failed to read request body: {e}"), }), ) .into_response() @@ -105,7 +105,7 @@ async fn create_conversation( ( StatusCode::BAD_GATEWAY, Json(ErrorResponse { - error: format!("OpenAI API error: {}", e), + error: format!("OpenAI API error: {e}"), }), ) .into_response() @@ -131,7 +131,7 @@ async fn create_conversation( ( StatusCode::BAD_GATEWAY, Json(ErrorResponse { - error: format!("Failed to read response: {}", e), + error: format!("Failed to read response: {e}"), }), ) .into_response() @@ -141,7 +141,7 @@ async fn create_conversation( .collect(); // If successful, parse response and track conversation - if status >= 200 && status < 300 { + if (200..300).contains(&status) { tracing::debug!("Parsing successful conversation creation response"); // Decompress if gzipped @@ -227,7 +227,7 @@ async fn create_conversation( ( StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse { - error: format!("Failed to build response: {}", e), + error: format!("Failed to build response: {e}"), }), ) .into_response() @@ -254,7 +254,7 @@ async fn list_conversations( ( StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse { - error: format!("Failed to list conversations: {}", e), + error: format!("Failed to list conversations: {e}"), }), ) .into_response() @@ -316,7 +316,7 @@ async fn proxy_handler( path ); - if body_bytes.len() > 0 { + if !body_bytes.is_empty() { if let Ok(body_str) = std::str::from_utf8(&body_bytes) { tracing::debug!("Request body content: {}", body_str); } @@ -358,7 +358,7 @@ async fn proxy_handler( ( StatusCode::BAD_GATEWAY, Json(ErrorResponse { - error: format!("OpenAI API error: {}", e), + error: format!("OpenAI API error: {e}"), }), ) .into_response() @@ -387,16 +387,14 @@ async fn proxy_handler( } // Convert the stream to an axum Body for streaming support - let stream = proxy_response - .body - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); + let stream = proxy_response.body.map_err(std::io::Error::other); let body = Body::from_stream(stream); response.body(body).map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse { - error: format!("Failed to build response: {}", e), + error: format!("Failed to build response: {e}"), }), ) .into_response() @@ -413,7 +411,7 @@ async fn extract_body_bytes(request: Request) -> Result { ( StatusCode::BAD_REQUEST, Json(ErrorResponse { - error: format!("Failed to read request body: {}", e), + error: format!("Failed to read request body: {e}"), }), ) .into_response() diff --git a/crates/api/src/routes/oauth.rs b/crates/api/src/routes/oauth.rs index 5c81b026..8f6f814b 100644 --- a/crates/api/src/routes/oauth.rs +++ b/crates/api/src/routes/oauth.rs @@ -100,8 +100,8 @@ pub async fn oauth_callback( ); // The provider is determined from the state stored in the database - // Returns (session, frontend_callback_url) - let (session, frontend_callback) = app_state + // Returns (session, frontend_callback_url, is_new_user) + let (session, frontend_callback, is_new_user) = app_state .oauth_service .handle_callback_unified(params.code.clone(), params.state.clone()) .await @@ -139,12 +139,15 @@ pub async fn oauth_callback( tracing::info!("Redirecting to frontend: {}", frontend_url); - let callback_url = format!( + let mut callback_url = format!( "{}/auth/callback?token={}&expires_at={}", frontend_url, urlencoding::encode(&token), urlencoding::encode(&session.expires_at.to_rfc3339()) ); + if is_new_user { + callback_url.push_str("&is_new_user=true"); + } tracing::debug!("Final callback URL: {}", callback_url); diff --git a/crates/api/tests/e2e_api_tests.rs b/crates/api/tests/e2e_api_tests.rs index a510103f..39cc20cd 100644 --- a/crates/api/tests/e2e_api_tests.rs +++ b/crates/api/tests/e2e_api_tests.rs @@ -88,13 +88,13 @@ async fn test_conversation_workflow() { .post("/v1/conversations") .add_header( http::HeaderName::from_static("authorization"), - http::HeaderValue::from_str(&format!("Bearer {}", SESSION_TOKEN)).unwrap(), + http::HeaderValue::from_str(&format!("Bearer {SESSION_TOKEN}")).unwrap(), ) .json(&create_conv_body) .await; let status = response.status_code(); - println!(" Status: {}", status); + println!(" Status: {status}"); let conversation_id = if status.is_success() { let body: serde_json::Value = response.json(); @@ -104,11 +104,11 @@ async fn test_conversation_workflow() { .and_then(|v| v.as_str()) .map(|s| s.to_string()) .expect("Conversation should have an ID"); - println!(" Conversation ID: {}", conv_id); + println!(" Conversation ID: {conv_id}"); conv_id } else { let error_text = response.text(); - println!(" ✗ Failed: {}", error_text); + println!(" ✗ Failed: {error_text}"); panic!("Failed to create conversation"); }; @@ -130,13 +130,13 @@ async fn test_conversation_workflow() { .post("/v1/responses") .add_header( http::HeaderName::from_static("authorization"), - http::HeaderValue::from_str(&format!("Bearer {}", SESSION_TOKEN)).unwrap(), + http::HeaderValue::from_str(&format!("Bearer {SESSION_TOKEN}")).unwrap(), ) .json(&request_body) .await; let status = response.status_code(); - println!(" Status: {}", status); + println!(" Status: {status}"); if status.is_success() { let body: serde_json::Value = response.json(); @@ -147,7 +147,7 @@ async fn test_conversation_workflow() { ); } else { let error_text = response.text(); - println!(" ✗ Failed: {}", error_text); + println!(" ✗ Failed: {error_text}"); panic!("Failed to create first response"); }; @@ -169,13 +169,13 @@ async fn test_conversation_workflow() { .post("/v1/responses") .add_header( http::HeaderName::from_static("authorization"), - http::HeaderValue::from_str(&format!("Bearer {}", SESSION_TOKEN)).unwrap(), + http::HeaderValue::from_str(&format!("Bearer {SESSION_TOKEN}")).unwrap(), ) .json(&request_body) .await; let status = response.status_code(); - println!(" Status: {}", status); + println!(" Status: {status}"); if status.is_success() { let body: serde_json::Value = response.json(); @@ -186,7 +186,7 @@ async fn test_conversation_workflow() { ); } else { let error_text = response.text(); - println!(" ✗ Failed: {}", error_text); + println!(" ✗ Failed: {error_text}"); panic!("Failed to create second response"); }; @@ -196,14 +196,14 @@ async fn test_conversation_workflow() { .get("/v1/conversations") .add_header( http::HeaderName::from_static("authorization"), - http::HeaderValue::from_str(&format!("Bearer {}", SESSION_TOKEN)).unwrap(), + http::HeaderValue::from_str(&format!("Bearer {SESSION_TOKEN}")).unwrap(), ) .await; assert_eq!(response.status_code(), 200, "Should list conversations"); let conversations: Vec = response.json(); - println!("{:?}", conversations); + println!("{conversations:?}"); println!(" Found {} total conversations", conversations.len()); // Find our conversation diff --git a/crates/services/src/auth/ports.rs b/crates/services/src/auth/ports.rs index fa8b3882..36ef24b1 100644 --- a/crates/services/src/auth/ports.rs +++ b/crates/services/src/auth/ports.rs @@ -103,7 +103,7 @@ pub trait OAuthService: Send + Sync { &self, code: String, state: String, - ) -> anyhow::Result<(UserSession, Option)>; + ) -> anyhow::Result<(UserSession, Option, bool)>; /// Refresh an access token async fn refresh_token( diff --git a/crates/services/src/auth/service.rs b/crates/services/src/auth/service.rs index a6a0a3b7..d7ad6f69 100644 --- a/crates/services/src/auth/service.rs +++ b/crates/services/src/auth/service.rs @@ -63,6 +63,7 @@ pub struct OAuthServiceImpl { } impl OAuthServiceImpl { + #[allow(clippy::too_many_arguments)] pub fn new( oauth_repository: Arc, session_repository: Arc, @@ -121,7 +122,7 @@ impl OAuthServiceImpl { name: user_data["name"].as_str().map(|s| s.to_string()), avatar_url: user_data["picture"].as_str().map(|s| s.to_string()), }; - + tracing::info!( "Successfully fetched Google user info: email={}, provider_user_id={}", user_info.email, @@ -169,7 +170,7 @@ impl OAuthServiceImpl { let emails: Vec = emails_response.json().await?; tracing::debug!("Github emails received: {} email(s)", emails.len()); - + let primary_email = emails .iter() .find(|e| e["primary"].as_bool().unwrap_or(false)) @@ -187,7 +188,7 @@ impl OAuthServiceImpl { name: user_data["name"].as_str().map(|s| s.to_string()), avatar_url: user_data["avatar_url"].as_str().map(|s| s.to_string()), }; - + tracing::info!( "Successfully fetched Github user info: email={}, provider_user_id={}", user_info.email, @@ -197,23 +198,25 @@ impl OAuthServiceImpl { Ok(user_info) } + /// Find or create user from OAuth + /// Returns (user_id, is_new_user) async fn find_or_create_user_from_oauth( &self, user_info: &OAuthUserInfo, - ) -> anyhow::Result { + ) -> anyhow::Result<(UserId, bool)> { tracing::info!( "Finding or creating user for OAuth login: provider={:?}, email={}", user_info.provider, user_info.email ); - + // First check if user exists by OAuth provider tracing::debug!( "Checking for existing user by OAuth: provider={:?}, provider_user_id={}", user_info.provider, user_info.provider_user_id ); - + if let Some(user_id) = self .user_repository .find_user_by_oauth(user_info.provider, &user_info.provider_user_id) @@ -224,11 +227,14 @@ impl OAuthServiceImpl { user_id, user_info.provider ); - return Ok(user_id); + return Ok((user_id, false)); } - tracing::debug!("No user found by OAuth, checking by email: {}", user_info.email); - + tracing::debug!( + "No user found by OAuth, checking by email: {}", + user_info.email + ); + // Check if user exists by email if let Some(existing_user) = self .user_repository @@ -247,7 +253,7 @@ impl OAuthServiceImpl { existing_user.id, user_info.provider ); - + self.user_repository .link_oauth_account( existing_user.id, @@ -262,7 +268,7 @@ impl OAuthServiceImpl { existing_user.id ); - return Ok(existing_user.id); + return Ok((existing_user.id, false)); } // Create new user @@ -270,7 +276,7 @@ impl OAuthServiceImpl { "No existing user found, creating new user for email: {}", user_info.email ); - + let user = self .user_repository .create_user( @@ -292,7 +298,7 @@ impl OAuthServiceImpl { user.id, user_info.provider ); - + self.user_repository .link_oauth_account( user.id, @@ -307,7 +313,7 @@ impl OAuthServiceImpl { user.id ); - Ok(user.id) + Ok((user.id, true)) } /// Internal implementation that handles the callback with a pre-validated state @@ -317,13 +323,13 @@ impl OAuthServiceImpl { provider: OAuthProvider, code: String, oauth_state: OAuthState, - ) -> anyhow::Result<(UserSession, Option)> { + ) -> anyhow::Result<(UserSession, Option, bool)> { tracing::info!( "Processing OAuth callback: provider={:?}, redirect_uri={}", provider, oauth_state.redirect_uri ); - + // Build client for token exchange let (client_id, client_secret, auth_url, token_url) = match provider { OAuthProvider::Google => ( @@ -340,8 +346,11 @@ impl OAuthServiceImpl { ), }; - tracing::debug!("Building OAuth client for token exchange with provider: {:?}", provider); - + tracing::debug!( + "Building OAuth client for token exchange with provider: {:?}", + provider + ); + let client = BasicClient::new(ClientId::new(client_id.clone())) .set_client_secret(ClientSecret::new(client_secret.clone())) .set_auth_uri(AuthUrl::new(auth_url.to_string())?) @@ -349,7 +358,7 @@ impl OAuthServiceImpl { .set_redirect_uri(RedirectUrl::new(oauth_state.redirect_uri.clone())?); tracing::debug!("Exchanging authorization code for access token"); - + let token_result = client .exchange_code(AuthorizationCode::new(code)) .request_async(&async_http_client) @@ -360,11 +369,11 @@ impl OAuthServiceImpl { })?; tracing::info!("Successfully exchanged authorization code for access token"); - + let access_token = token_result.access_token().secret(); let has_refresh_token = token_result.refresh_token().is_some(); let expires_in = token_result.expires_in(); - + tracing::debug!( "Token details: has_refresh_token={}, expires_in={:?}", has_refresh_token, @@ -379,7 +388,7 @@ impl OAuthServiceImpl { }; // Find or create user - let user_id = self.find_or_create_user_from_oauth(&user_info).await?; + let (user_id, is_new_user) = self.find_or_create_user_from_oauth(&user_info).await?; // Store OAuth tokens let oauth_tokens = OAuthTokens { @@ -395,7 +404,7 @@ impl OAuthServiceImpl { user_id, provider ); - + self.oauth_repository .store_oauth_tokens(user_id, provider, &oauth_tokens) .await?; @@ -413,7 +422,7 @@ impl OAuthServiceImpl { session.session_id ); - Ok((session, oauth_state.frontend_callback)) + Ok((session, oauth_state.frontend_callback, is_new_user)) } } @@ -431,7 +440,7 @@ impl OAuthService for OAuthServiceImpl { redirect_uri, frontend_callback ); - + let (client_id, client_secret, auth_url, token_url, scopes) = match provider { OAuthProvider::Google => ( &self.google_client_id, @@ -463,7 +472,7 @@ impl OAuthService for OAuthServiceImpl { } let (auth_url, csrf_token) = auth_request.url(); - + tracing::debug!("Generated CSRF token: {}", csrf_token.secret()); // Store the state for verification @@ -476,7 +485,7 @@ impl OAuthService for OAuthServiceImpl { }; tracing::debug!("Storing OAuth state in database"); - + self.oauth_repository .store_oauth_state(&oauth_state) .await?; @@ -557,9 +566,9 @@ impl OAuthService for OAuthServiceImpl { &self, code: String, state: String, - ) -> anyhow::Result<(UserSession, Option)> { + ) -> anyhow::Result<(UserSession, Option, bool)> { tracing::info!("Handling unified OAuth callback with state: {}", state); - + // First, look up the state to determine the provider tracing::debug!("Looking up OAuth state in database"); let oauth_state = self @@ -584,9 +593,9 @@ impl OAuthService for OAuthServiceImpl { async fn revoke_session(&self, session_id: SessionId) -> anyhow::Result<()> { tracing::info!("Revoking session: session_id={}", session_id); - + self.session_repository.delete_session(session_id).await?; - + tracing::info!("Session revoked successfully: session_id={}", session_id); Ok(()) } diff --git a/crates/services/src/types.rs b/crates/services/src/types.rs index a5e95bbf..40f7f1de 100644 --- a/crates/services/src/types.rs +++ b/crates/services/src/types.rs @@ -133,7 +133,7 @@ mod tests { fn test_id_display() { let uuid = Uuid::new_v4(); let user_id = UserId(uuid); - assert_eq!(format!("{}", user_id), format!("{}", uuid)); + assert_eq!(format!("{user_id}"), format!("{}", uuid)); } #[test] diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..961b47e2 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "1.88.0" +components = ["cargo", "rustfmt", "clippy"]