diff --git a/src/rotator_library/providers/antigravity_auth_base.py b/src/rotator_library/providers/antigravity_auth_base.py index 0e464570..2902dc65 100644 --- a/src/rotator_library/providers/antigravity_auth_base.py +++ b/src/rotator_library/providers/antigravity_auth_base.py @@ -10,16 +10,18 @@ import httpx from .google_oauth_base import GoogleOAuthBase -from .utilities.gemini_shared_utils import CODE_ASSIST_ENDPOINT +# Note: Endpoint constants are imported by helper methods from gemini_shared_utils lib_logger = logging.getLogger("rotator_library") # Headers for Antigravity auth/discovery calls +# Uses Gemini CLI style User-Agent/X-Goog-Api-Client for compatibility with newer API versions, +# with Antigravity-specific Client-Metadata (JSON format) # Note: ideType in Client-Metadata header stays IDE_UNSPECIFIED for compatibility, # while ideType in request body metadata uses "ANTIGRAVITY" ANTIGRAVITY_AUTH_HEADERS = { - "User-Agent": "google-api-nodejs-client/9.15.1", - "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1", + "User-Agent": "google-api-nodejs-client/10.3.0", + "X-Goog-Api-Client": "gl-node/22.18.0", "Client-Metadata": '{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}', } @@ -100,6 +102,147 @@ async def _post_auth_discovery( f"tier={tier}, project={project_id}" ) + # ========================================================================= + # ENDPOINT FALLBACK HELPERS + # ========================================================================= + + def _extract_project_id_from_response( + self, data: Dict[str, Any], key: str = "cloudaicompanionProject" + ) -> Optional[str]: + """ + Extract project ID from API response, handling both string and object formats. + + The API may return cloudaicompanionProject as either: + - A string: "project-id-123" + - An object: {"id": "project-id-123", ...} + + Args: + data: API response data + key: Key to extract from (default: "cloudaicompanionProject") + + Returns: + Project ID string or None if not found + """ + value = data.get(key) + if isinstance(value, str) and value: + return value + if isinstance(value, dict): + return value.get("id") + return None + + async def _call_load_code_assist( + self, + client: httpx.AsyncClient, + access_token: str, + configured_project_id: Optional[str], + headers: Dict[str, str], + ) -> tuple: + """ + Call loadCodeAssist with endpoint fallback chain. + + Tries endpoints in ANTIGRAVITY_LOAD_ENDPOINT_ORDER (prod first for better + project resolution, then fallback to sandbox). + + Args: + client: httpx async client + access_token: OAuth access token + configured_project_id: User-configured project ID (or None) + headers: Request headers + + Returns: + Tuple of (response_data, successful_endpoint) or (None, None) on failure + """ + from .utilities.gemini_shared_utils import ANTIGRAVITY_LOAD_ENDPOINT_ORDER + + core_client_metadata = { + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + } + if configured_project_id: + core_client_metadata["duetProject"] = configured_project_id + + load_request = { + "cloudaicompanionProject": configured_project_id, + "metadata": core_client_metadata, + } + + last_error = None + for endpoint in ANTIGRAVITY_LOAD_ENDPOINT_ORDER: + try: + lib_logger.debug(f"Trying loadCodeAssist at {endpoint}") + response = await client.post( + f"{endpoint}:loadCodeAssist", + headers=headers, + json=load_request, + timeout=15, + ) + if response.status_code == 200: + data = response.json() + lib_logger.debug(f"loadCodeAssist succeeded at {endpoint}") + return data, endpoint + lib_logger.debug( + f"loadCodeAssist returned {response.status_code} at {endpoint}" + ) + last_error = f"HTTP {response.status_code}" + except Exception as e: + lib_logger.debug(f"loadCodeAssist failed at {endpoint}: {e}") + last_error = str(e) + continue + + lib_logger.warning( + f"All loadCodeAssist endpoints failed. Last error: {last_error}" + ) + return None, None + + async def _call_onboard_user( + self, + client: httpx.AsyncClient, + headers: Dict[str, str], + onboard_request: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + """ + Call onboardUser with endpoint fallback chain. + + Tries endpoints in ANTIGRAVITY_ENDPOINT_FALLBACKS (daily first, then prod). + + Args: + client: httpx async client + headers: Request headers + onboard_request: Onboarding request payload + + Returns: + Response data dict or None on failure + """ + from .utilities.gemini_shared_utils import ANTIGRAVITY_ENDPOINT_FALLBACKS + + last_error = None + for endpoint in ANTIGRAVITY_ENDPOINT_FALLBACKS: + try: + lib_logger.debug(f"Trying onboardUser at {endpoint}") + response = await client.post( + f"{endpoint}:onboardUser", + headers=headers, + json=onboard_request, + timeout=30, + ) + if response.status_code == 200: + lib_logger.debug(f"onboardUser succeeded at {endpoint}") + return response.json() + lib_logger.debug( + f"onboardUser returned {response.status_code} at {endpoint}" + ) + last_error = f"HTTP {response.status_code}" + except Exception as e: + lib_logger.debug(f"onboardUser failed at {endpoint}: {e}") + last_error = str(e) + continue + + lib_logger.warning( + f"All onboardUser endpoints failed. Last error: {last_error}" + ) + return None + # ========================================================================= # PROJECT ID DISCOVERY # ========================================================================= @@ -203,45 +346,39 @@ async def _discover_project_id( **ANTIGRAVITY_AUTH_HEADERS, } + # Build core metadata for API requests + core_client_metadata = { + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + } + if configured_project_id: + core_client_metadata["duetProject"] = configured_project_id + discovered_project_id = None discovered_tier = None async with httpx.AsyncClient() as client: - # 1. Try discovery endpoint with loadCodeAssist + # 1. Try discovery endpoint with loadCodeAssist using endpoint fallback lib_logger.debug( "Attempting project discovery via Code Assist loadCodeAssist endpoint..." ) try: - # Build metadata - include duetProject only if we have a configured project - core_client_metadata = { - "ideType": "ANTIGRAVITY", - "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI", - } - if configured_project_id: - core_client_metadata["duetProject"] = configured_project_id - - # Build load request - pass configured_project_id if available, otherwise None - load_request = { - "cloudaicompanionProject": configured_project_id, # Can be None - "metadata": core_client_metadata, - } - - lib_logger.debug( - f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}" - ) - response = await client.post( - f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist", - headers=headers, - json=load_request, - timeout=20, + # Use helper with endpoint fallback chain + data, successful_endpoint = await self._call_load_code_assist( + client, access_token, configured_project_id, headers ) - response.raise_for_status() - data = response.json() - # Log full response for debugging + if data is None: + # All endpoints failed - skip to GCP Resource Manager fallback + raise httpx.HTTPStatusError( + "All loadCodeAssist endpoints failed", + request=None, + response=None, + ) + lib_logger.debug( - f"loadCodeAssist full response keys: {list(data.keys())}" + f"loadCodeAssist succeeded at {successful_endpoint}, response keys: {list(data.keys())}" ) # Extract tier information @@ -269,7 +406,8 @@ async def _discover_project_id( # Check if user is already known to server (has currentTier) if current_tier_id: # User is already onboarded - check for project from server - server_project = data.get("cloudaicompanionProject") + # Use helper to handle both string and object formats + server_project = self._extract_project_id_from_response(data) # Check if this tier requires user-defined project (paid tiers) requires_user_project = any( @@ -405,58 +543,52 @@ async def _discover_project_id( f"Paid tier onboarding: using project {configured_project_id}" ) - lib_logger.debug("Initiating onboardUser request...") - lro_response = await client.post( - f"{CODE_ASSIST_ENDPOINT}:onboardUser", - headers=headers, - json=onboard_request, - timeout=30, + lib_logger.debug( + "Initiating onboardUser request with endpoint fallback..." ) - lro_response.raise_for_status() - lro_data = lro_response.json() + lro_data = await self._call_onboard_user( + client, headers, onboard_request + ) + + if lro_data is None: + raise ValueError( + "All onboardUser endpoints failed. Cannot onboard user." + ) + lib_logger.debug( f"Initial onboarding response: done={lro_data.get('done')}" ) - # Poll for onboarding completion (up to 5 minutes) - for i in range(150): # 150 × 2s = 5 minutes + # Poll for onboarding completion (up to 60 seconds) + for i in range(30): # 30 × 2s = 60 seconds if lro_data.get("done"): - lib_logger.debug( - f"Onboarding completed after {i} polling attempts" - ) + lib_logger.debug(f"Onboarding completed after {i * 2}s") break await asyncio.sleep(2) - if (i + 1) % 15 == 0: # Log every 30 seconds + if (i + 1) % 10 == 0: # Log every 20 seconds lib_logger.info( f"Still waiting for onboarding completion... ({(i + 1) * 2}s elapsed)" ) lib_logger.debug( - f"Polling onboarding status... (Attempt {i + 1}/150)" + f"Polling onboarding status... (Attempt {i + 1}/30)" ) - lro_response = await client.post( - f"{CODE_ASSIST_ENDPOINT}:onboardUser", - headers=headers, - json=onboard_request, - timeout=30, + lro_data = await self._call_onboard_user( + client, headers, onboard_request ) - lro_response.raise_for_status() - lro_data = lro_response.json() + if lro_data is None: + lib_logger.warning("onboardUser endpoint failed during polling") + break - if not lro_data.get("done"): - lib_logger.error("Onboarding process timed out after 5 minutes") + if not lro_data or not lro_data.get("done"): + lib_logger.error("Onboarding process timed out after 60 seconds") raise ValueError( - "Onboarding process timed out after 5 minutes. Please try again or contact support." + "Onboarding process timed out after 60 seconds. Please try again or contact support." ) - # Extract project ID from LRO response + # Extract project ID from LRO response using helper # Note: onboardUser returns response.cloudaicompanionProject as an object with .id lro_response_data = lro_data.get("response", {}) - lro_project_obj = lro_response_data.get("cloudaicompanionProject", {}) - project_id = ( - lro_project_obj.get("id") - if isinstance(lro_project_obj, dict) - else None - ) + project_id = self._extract_project_id_from_response(lro_response_data) # Fallback to configured project if LRO didn't return one if not project_id and configured_project_id: diff --git a/src/rotator_library/providers/google_oauth_base.py b/src/rotator_library/providers/google_oauth_base.py index 493d39c1..60dc6fea 100644 --- a/src/rotator_library/providers/google_oauth_base.py +++ b/src/rotator_library/providers/google_oauth_base.py @@ -93,6 +93,77 @@ class GoogleOAuthBase: CALLBACK_PATH: str = DEFAULT_OAUTH_CALLBACK_PATH REFRESH_EXPIRY_BUFFER_SECONDS: int = DEFAULT_REFRESH_EXPIRY_BUFFER + # ========================================================================= + # PKCE (Proof Key for Code Exchange) SUPPORT + # ========================================================================= + + def _generate_pkce(self) -> tuple: + """ + Generate PKCE code_verifier and code_challenge. + + PKCE (Proof Key for Code Exchange) prevents authorization code interception attacks. + Required for public OAuth clients per RFC 7636. + + Returns: + Tuple of (code_verifier, code_challenge) + """ + import secrets + import hashlib + import base64 + + # code_verifier: 43-128 chars, using URL-safe base64 + code_verifier = secrets.token_urlsafe(32) # Produces 43 chars + + # code_challenge: BASE64URL(SHA256(code_verifier)) without padding + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + code_challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=") + + return code_verifier, code_challenge + + def _encode_oauth_state(self, code_verifier: str) -> str: + """ + Encode OAuth state parameter containing PKCE verifier. + + The state parameter provides CSRF protection and carries the PKCE verifier + so it can be recovered after the OAuth callback. + + Args: + code_verifier: The PKCE code verifier to encode + + Returns: + Base64url-encoded state string + """ + import base64 + + state_data = {"v": code_verifier} # Minimal - just verifier + json_bytes = json.dumps(state_data, separators=(",", ":")).encode("utf-8") + return base64.urlsafe_b64encode(json_bytes).decode("ascii").rstrip("=") + + def _decode_oauth_state(self, state: str) -> str: + """ + Decode OAuth state and return code_verifier. + + Args: + state: The base64url-encoded state string from OAuth callback + + Returns: + The decoded code_verifier + + Raises: + ValueError: If state cannot be decoded or is missing verifier + """ + import base64 + + # Re-pad base64 string (base64url encoding strips padding) + padded = state + "=" * (-len(state) % 4) + try: + state_data = json.loads(base64.urlsafe_b64decode(padded).decode("utf-8")) + if "v" not in state_data: + raise ValueError("Missing verifier in state") + return state_data["v"] + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise ValueError(f"Invalid state parameter: {e}") + @property def callback_port(self) -> int: """ @@ -912,6 +983,12 @@ async def _perform_interactive_oauth( # [HEADLESS DETECTION] Check if running in headless environment is_headless = is_headless_environment() + # [PKCE] Generate PKCE code verifier and challenge for enhanced security + code_verifier, code_challenge = self._generate_pkce() + + # [STATE] Encode state parameter with PKCE verifier for CSRF protection + oauth_state = self._encode_oauth_state(code_verifier) + auth_code_future = asyncio.get_event_loop().create_future() server = None @@ -928,8 +1005,13 @@ async def handle_callback(reader, writer): query_params = parse_qs(urlparse(path_str).query) writer.write(b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n") if "code" in query_params: + # Extract code and state from callback + auth_code = query_params["code"][0] + received_state = query_params.get("state", [None])[0] + if not auth_code_future.done(): - auth_code_future.set_result(query_params["code"][0]) + # Return both code and state for validation + auth_code_future.set_result((auth_code, received_state)) writer.write( b"
You can close this window.
" ) @@ -954,6 +1036,7 @@ async def handle_callback(reader, writer): ) from urllib.parse import urlencode + # [PKCE + STATE] Include code_challenge, code_challenge_method, and state in auth URL auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode( { "client_id": self.CLIENT_ID, @@ -962,6 +1045,9 @@ async def handle_callback(reader, writer): "access_type": "offline", "response_type": "code", "prompt": "consent", + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": oauth_state, } ) @@ -1016,7 +1102,30 @@ async def handle_callback(reader, writer): ): # Note: The 300s timeout here is handled by the ReauthCoordinator # We use a slightly longer internal timeout to let the coordinator handle it - auth_code = await asyncio.wait_for(auth_code_future, timeout=310) + auth_code, received_state = await asyncio.wait_for( + auth_code_future, timeout=310 + ) + + # [STATE VALIDATION] Validate state parameter and extract verifier + effective_verifier = code_verifier # Default to original verifier + if received_state: + try: + decoded_verifier = self._decode_oauth_state(received_state) + if decoded_verifier != code_verifier: + lib_logger.warning( + "OAuth state verifier mismatch - possible CSRF attempt. " + "Using original verifier." + ) + else: + effective_verifier = decoded_verifier + except ValueError as e: + lib_logger.warning( + f"Failed to decode OAuth state: {e}. Using original verifier." + ) + else: + lib_logger.debug( + "No state parameter in callback - using original verifier" + ) except asyncio.TimeoutError: raise Exception("OAuth flow timed out. Please try again.") finally: @@ -1026,14 +1135,22 @@ async def handle_callback(reader, writer): lib_logger.info(f"Attempting to exchange authorization code for tokens...") async with httpx.AsyncClient() as client: + # [PKCE + HEADERS] Include code_verifier and explicit headers for token exchange response = await client.post( self.TOKEN_URI, + headers={ + "Accept": "*/*", + "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8", + "User-Agent": "google-api-nodejs-client/10.3.0", + "X-Goog-Api-Client": "gl-node/22.18.0", + }, data={ "code": auth_code.strip(), "client_id": self.CLIENT_ID, "client_secret": self.CLIENT_SECRET, "redirect_uri": f"http://localhost:{self.callback_port}{self.CALLBACK_PATH}", "grant_type": "authorization_code", + "code_verifier": effective_verifier, }, ) response.raise_for_status() diff --git a/src/rotator_library/providers/utilities/gemini_shared_utils.py b/src/rotator_library/providers/utilities/gemini_shared_utils.py index 05d36d98..d1c26291 100644 --- a/src/rotator_library/providers/utilities/gemini_shared_utils.py +++ b/src/rotator_library/providers/utilities/gemini_shared_utils.py @@ -47,6 +47,32 @@ def env_int(key: str, default: int) -> int: "https://cloudcode-pa.googleapis.com/v1internal", # Production fallback ] +# ============================================================================= +# ANTIGRAVITY ENDPOINTS +# ============================================================================= + +# Antigravity API endpoint constants +# Sandbox endpoints often have different rate limits or newer features +ANTIGRAVITY_ENDPOINT_DAILY = ( + "https://daily-cloudcode-pa.sandbox.googleapis.com/v1internal" +) +ANTIGRAVITY_ENDPOINT_PROD = "https://cloudcode-pa.googleapis.com/v1internal" +# ANTIGRAVITY_ENDPOINT_AUTOPUSH = "https://autopush-cloudcode-pa.sandbox.googleapis.com/v1internal" # Reserved for future use + +# Antigravity endpoint fallback chain for API requests +# Order: sandbox daily -> production (matches CLIProxy/Vibeproxy behavior) +ANTIGRAVITY_ENDPOINT_FALLBACKS = [ + ANTIGRAVITY_ENDPOINT_DAILY, # Daily sandbox first + ANTIGRAVITY_ENDPOINT_PROD, # Production fallback +] + +# Endpoint order for loadCodeAssist (project discovery) +# Production first for better project resolution, then fallback to sandbox +ANTIGRAVITY_LOAD_ENDPOINT_ORDER = [ + ANTIGRAVITY_ENDPOINT_PROD, # Prod first for discovery + ANTIGRAVITY_ENDPOINT_DAILY, # Daily fallback +] + # ============================================================================= # GEMINI 3 TOOL RENAMING CONSTANTS