|
| 1 | +""" |
| 2 | +Copyright (C) 2026 Garudex Labs. All Rights Reserved. |
| 3 | +Caracal, a product of Garudex Labs |
| 4 | +
|
| 5 | +SDK Gateway Adapter. |
| 6 | +
|
| 7 | +Provides a transport adapter that routes mandate issuance, validation, |
| 8 | +and revocation through the enterprise gateway instead of the Caracal API |
| 9 | +directly. |
| 10 | +
|
| 11 | +OSS: behaves identically to the standard HTTP adapter (broker mode). |
| 12 | +Enterprise: wraps every request with the gateway's auth headers and routes |
| 13 | + through CARACAL_GATEWAY_ENDPOINT, gaining network-level enforcement. |
| 14 | +
|
| 15 | +Usage (automatic — gateway flags from environment / config): |
| 16 | +
|
| 17 | + from caracal.sdk.client import CaracalClient |
| 18 | +
|
| 19 | + # Feature flags auto-detected; GatewayAdapter used when gateway_enabled=True |
| 20 | + client = CaracalClient(api_key="…") |
| 21 | + mandate = await client.scope(...).mandates.create(...) |
| 22 | +
|
| 23 | +Manual override: |
| 24 | +
|
| 25 | + from caracal.sdk.gateway import GatewayAdapter |
| 26 | + from caracal.sdk.adapters.base import SDKRequest |
| 27 | +
|
| 28 | + adapter = GatewayAdapter( |
| 29 | + gateway_endpoint="https://gw.example.com", |
| 30 | + gateway_api_key="gw_key", |
| 31 | + org_id="org_123", |
| 32 | + ) |
| 33 | + client = CaracalClient(adapter=adapter) |
| 34 | +""" |
| 35 | + |
| 36 | +from __future__ import annotations |
| 37 | + |
| 38 | +import asyncio |
| 39 | +import hashlib |
| 40 | +import time |
| 41 | +from typing import Any, Dict, Optional |
| 42 | + |
| 43 | +import httpx |
| 44 | + |
| 45 | +from caracal.core.gateway_features import ( |
| 46 | + GatewayFeatureFlags, |
| 47 | + get_gateway_features, |
| 48 | + DEPLOYMENT_OSS, |
| 49 | +) |
| 50 | +from caracal.logging_config import get_logger |
| 51 | +from caracal.sdk.adapters.base import BaseAdapter, SDKRequest, SDKResponse |
| 52 | + |
| 53 | +logger = get_logger(__name__) |
| 54 | + |
| 55 | +_GW_REQUEST_TIMEOUT = 30 |
| 56 | + |
| 57 | + |
| 58 | +class GatewayAdapterError(Exception): |
| 59 | + """Raised when the gateway adapter encounters a non-retriable error.""" |
| 60 | + |
| 61 | + |
| 62 | +class GatewayAdapter(BaseAdapter): |
| 63 | + """ |
| 64 | + Transport adapter that proxies SDK requests through the Caracal |
| 65 | + enterprise gateway. |
| 66 | +
|
| 67 | + In OSS / broker mode this adapter calls the Caracal API directly |
| 68 | + (identical to HttpAdapter) unless *gateway_endpoint* is set. |
| 69 | +
|
| 70 | + In enterprise mode (gateway_endpoint + api_key configured) every |
| 71 | + request is forwarded to the gateway, which performs: |
| 72 | + - Mandate revocation check (fail-closed) |
| 73 | + - Provider registry resolution |
| 74 | + - Per-tenant quota enforcement |
| 75 | + - Secret binding for upstream credentials |
| 76 | + - Metering event emission |
| 77 | + """ |
| 78 | + |
| 79 | + # Headers injected on every outbound request |
| 80 | + GATEWAY_API_KEY_HEADER = "X-Gateway-Key" |
| 81 | + GATEWAY_ORG_HEADER = "X-Caracal-Org-ID" |
| 82 | + GATEWAY_WORKSPACE_HEADER = "X-Caracal-Workspace-ID" |
| 83 | + |
| 84 | + def __init__( |
| 85 | + self, |
| 86 | + gateway_endpoint: Optional[str] = None, |
| 87 | + gateway_api_key: Optional[str] = None, |
| 88 | + org_id: Optional[str] = None, |
| 89 | + workspace_id: Optional[str] = None, |
| 90 | + fallback_base_url: Optional[str] = None, |
| 91 | + timeout_seconds: int = _GW_REQUEST_TIMEOUT, |
| 92 | + feature_flags: Optional[GatewayFeatureFlags] = None, |
| 93 | + ) -> None: |
| 94 | + """ |
| 95 | + Args: |
| 96 | + gateway_endpoint: Base URL of the enterprise gateway proxy. |
| 97 | + Defaults to CARACAL_GATEWAY_ENDPOINT env var. |
| 98 | + gateway_api_key: API key for gateway authentication. |
| 99 | + Defaults to CARACAL_GATEWAY_API_KEY env var. |
| 100 | + org_id: Organization identifier injected into every request. |
| 101 | + workspace_id: Workspace identifier injected into every request. |
| 102 | + fallback_base_url: Caracal API URL used when gateway is disabled. |
| 103 | + timeout_seconds: HTTP request timeout. |
| 104 | + feature_flags: Pre-loaded feature flags (loaded from env if None). |
| 105 | + """ |
| 106 | + self._flags = feature_flags or get_gateway_features() |
| 107 | + self._endpoint = ( |
| 108 | + gateway_endpoint or self._flags.gateway_endpoint or "" |
| 109 | + ).rstrip("/") |
| 110 | + self._api_key = gateway_api_key or self._flags.gateway_api_key or "" |
| 111 | + self._org_id = org_id or "" |
| 112 | + self._workspace_id = workspace_id or "" |
| 113 | + self._fallback_base = (fallback_base_url or "").rstrip("/") |
| 114 | + self._timeout = timeout_seconds |
| 115 | + |
| 116 | + self._client: Optional[httpx.AsyncClient] = None |
| 117 | + self._connected = False |
| 118 | + |
| 119 | + # ── BaseAdapter interface ───────────────────────────────────────────────── |
| 120 | + |
| 121 | + @property |
| 122 | + def is_connected(self) -> bool: |
| 123 | + return self._connected |
| 124 | + |
| 125 | + async def send(self, request: SDKRequest) -> SDKResponse: |
| 126 | + """Route the request through the gateway (or direct API in OSS mode).""" |
| 127 | + client = self._get_client() |
| 128 | + |
| 129 | + if self._should_use_gateway(): |
| 130 | + return await self._send_via_gateway(client, request) |
| 131 | + return await self._send_direct(client, request) |
| 132 | + |
| 133 | + def close(self) -> None: |
| 134 | + if self._client: |
| 135 | + try: |
| 136 | + loop = asyncio.get_event_loop() |
| 137 | + if loop.is_running(): |
| 138 | + loop.create_task(self._client.aclose()) |
| 139 | + else: |
| 140 | + loop.run_until_complete(self._client.aclose()) |
| 141 | + except Exception: |
| 142 | + pass |
| 143 | + self._client = None |
| 144 | + self._connected = False |
| 145 | + |
| 146 | + # ── Internal helpers ────────────────────────────────────────────────────── |
| 147 | + |
| 148 | + def _should_use_gateway(self) -> bool: |
| 149 | + return bool(self._flags.gateway_enabled and self._endpoint and self._flags.is_enterprise) |
| 150 | + |
| 151 | + def _get_client(self) -> httpx.AsyncClient: |
| 152 | + if not self._client: |
| 153 | + self._client = httpx.AsyncClient( |
| 154 | + timeout=httpx.Timeout(self._timeout), |
| 155 | + follow_redirects=True, |
| 156 | + ) |
| 157 | + self._connected = True |
| 158 | + return self._client |
| 159 | + |
| 160 | + async def _send_via_gateway( |
| 161 | + self, client: httpx.AsyncClient, request: SDKRequest |
| 162 | + ) -> SDKResponse: |
| 163 | + """Forward request to the enterprise gateway proxy.""" |
| 164 | + url = f"{self._endpoint}{request.path}" |
| 165 | + |
| 166 | + headers = dict(request.headers) |
| 167 | + # Inject gateway auth headers |
| 168 | + if self._api_key: |
| 169 | + headers[self.GATEWAY_API_KEY_HEADER] = self._api_key |
| 170 | + if self._org_id: |
| 171 | + headers[self.GATEWAY_ORG_HEADER] = self._org_id |
| 172 | + if self._workspace_id: |
| 173 | + headers[self.GATEWAY_WORKSPACE_HEADER] = self._workspace_id |
| 174 | + |
| 175 | + # Signal to gateway that this is an SDK call (not a direct agent forward) |
| 176 | + headers["X-Caracal-SDK-Call"] = "1" |
| 177 | + headers["X-Caracal-Deployment"] = self._flags.deployment_type |
| 178 | + |
| 179 | + start = time.monotonic() |
| 180 | + try: |
| 181 | + if request.method.upper() in ("GET", "DELETE", "HEAD"): |
| 182 | + resp = await client.request( |
| 183 | + method=request.method, |
| 184 | + url=url, |
| 185 | + headers=headers, |
| 186 | + params=request.params, |
| 187 | + ) |
| 188 | + else: |
| 189 | + resp = await client.request( |
| 190 | + method=request.method, |
| 191 | + url=url, |
| 192 | + headers=headers, |
| 193 | + params=request.params, |
| 194 | + json=request.body, |
| 195 | + ) |
| 196 | + except httpx.TimeoutException as exc: |
| 197 | + if self._flags.fail_closed: |
| 198 | + raise GatewayAdapterError( |
| 199 | + f"Gateway request timed out (fail-closed): {exc}" |
| 200 | + ) from exc |
| 201 | + logger.warning("Gateway timeout; falling back to direct API: %s", exc) |
| 202 | + return await self._send_direct(client, request) |
| 203 | + except httpx.HTTPError as exc: |
| 204 | + if self._flags.fail_closed: |
| 205 | + raise GatewayAdapterError( |
| 206 | + f"Gateway unreachable (fail-closed): {exc}" |
| 207 | + ) from exc |
| 208 | + logger.warning("Gateway unreachable; falling back to direct API: %s", exc) |
| 209 | + return await self._send_direct(client, request) |
| 210 | + |
| 211 | + elapsed = (time.monotonic() - start) * 1000 |
| 212 | + self._raise_if_gateway_error(resp) |
| 213 | + |
| 214 | + return SDKResponse( |
| 215 | + status_code=resp.status_code, |
| 216 | + headers=dict(resp.headers), |
| 217 | + body=self._parse_body(resp), |
| 218 | + elapsed_ms=elapsed, |
| 219 | + ) |
| 220 | + |
| 221 | + async def _send_direct( |
| 222 | + self, client: httpx.AsyncClient, request: SDKRequest |
| 223 | + ) -> SDKResponse: |
| 224 | + """Direct call to the Caracal API (OSS broker path).""" |
| 225 | + base = self._fallback_base |
| 226 | + if not base: |
| 227 | + raise GatewayAdapterError( |
| 228 | + "No gateway endpoint and no fallback_base_url configured." |
| 229 | + ) |
| 230 | + url = f"{base}{request.path}" |
| 231 | + start = time.monotonic() |
| 232 | + if request.method.upper() in ("GET", "DELETE", "HEAD"): |
| 233 | + resp = await client.request( |
| 234 | + method=request.method, |
| 235 | + url=url, |
| 236 | + headers=request.headers, |
| 237 | + params=request.params, |
| 238 | + ) |
| 239 | + else: |
| 240 | + resp = await client.request( |
| 241 | + method=request.method, |
| 242 | + url=url, |
| 243 | + headers=request.headers, |
| 244 | + params=request.params, |
| 245 | + json=request.body, |
| 246 | + ) |
| 247 | + elapsed = (time.monotonic() - start) * 1000 |
| 248 | + return SDKResponse( |
| 249 | + status_code=resp.status_code, |
| 250 | + headers=dict(resp.headers), |
| 251 | + body=self._parse_body(resp), |
| 252 | + elapsed_ms=elapsed, |
| 253 | + ) |
| 254 | + |
| 255 | + def _raise_if_gateway_error(self, resp: httpx.Response) -> None: |
| 256 | + """Translate gateway-specific error codes to typed exceptions.""" |
| 257 | + if resp.status_code == 401: |
| 258 | + raise GatewayAdapterError("Gateway rejected API key (401 Unauthorized).") |
| 259 | + if resp.status_code == 403: |
| 260 | + body = self._parse_body(resp) or {} |
| 261 | + error = body.get("error", "forbidden") if isinstance(body, dict) else "forbidden" |
| 262 | + if error == "mandate_revoked": |
| 263 | + from caracal.exceptions import AuthorityDeniedError |
| 264 | + raise AuthorityDeniedError("Mandate has been revoked.") |
| 265 | + if error == "provider_not_allowed": |
| 266 | + raise GatewayAdapterError( |
| 267 | + f"Provider not in registry: {body.get('message', '')}" |
| 268 | + ) |
| 269 | + raise GatewayAdapterError(f"Gateway denied request: {error}") |
| 270 | + if resp.status_code == 429: |
| 271 | + body = self._parse_body(resp) or {} |
| 272 | + raise GatewayAdapterError( |
| 273 | + f"Quota exceeded: {body.get('dimension', 'unknown')} " |
| 274 | + f"({body.get('current')}/{body.get('limit')})" |
| 275 | + ) |
| 276 | + if resp.status_code == 503: |
| 277 | + raise GatewayAdapterError("Gateway unavailable (503).") |
| 278 | + |
| 279 | + @staticmethod |
| 280 | + def _parse_body(resp: httpx.Response) -> Any: |
| 281 | + ct = resp.headers.get("content-type", "") |
| 282 | + if "application/json" in ct: |
| 283 | + try: |
| 284 | + return resp.json() |
| 285 | + except Exception: |
| 286 | + return resp.text |
| 287 | + return resp.text or None |
| 288 | + |
| 289 | + |
| 290 | +def build_gateway_adapter( |
| 291 | + org_id: Optional[str] = None, |
| 292 | + workspace_id: Optional[str] = None, |
| 293 | + fallback_base_url: Optional[str] = None, |
| 294 | +) -> GatewayAdapter: |
| 295 | + """ |
| 296 | + Convenience factory: build a GatewayAdapter from environment feature flags. |
| 297 | +
|
| 298 | + Returns a GatewayAdapter configured from CARACAL_GATEWAY_* env vars. |
| 299 | + OSS users without gateway flags configured will get a simple direct adapter. |
| 300 | + """ |
| 301 | + flags = get_gateway_features() |
| 302 | + return GatewayAdapter( |
| 303 | + gateway_endpoint=flags.gateway_endpoint, |
| 304 | + gateway_api_key=flags.gateway_api_key, |
| 305 | + org_id=org_id, |
| 306 | + workspace_id=workspace_id, |
| 307 | + fallback_base_url=fallback_base_url, |
| 308 | + feature_flags=flags, |
| 309 | + ) |
0 commit comments