diff --git a/cmd/greyproxy/program.go b/cmd/greyproxy/program.go index d15a66c..795734d 100644 --- a/cmd/greyproxy/program.go +++ b/cmd/greyproxy/program.go @@ -15,6 +15,7 @@ import ( "runtime" "strconv" "strings" + "time" "syscall" "github.com/andybalholm/brotli" @@ -43,8 +44,9 @@ type program struct { srvGreyproxy *greyproxy.Service srvProfiling *http.Server - cancel context.CancelFunc - assemblerCancel context.CancelFunc + cancel context.CancelFunc + assemblerCancel context.CancelFunc + credStoreCancel context.CancelFunc } func (p *program) initParser() { @@ -204,6 +206,9 @@ func (p *program) Stop(s service.Service) error { p.srvProfiling.Close() logger.Default().Debug("service @profiling shutdown") } + if p.credStoreCancel != nil { + p.credStoreCancel() + } if p.assemblerCancel != nil { p.assemblerCancel() } @@ -322,6 +327,44 @@ func (p *program) buildGreyproxyService() error { gostx.SetGlobalMitmEnabled(enabled) }) + // Initialize credential substitution encryption key and store + encKey, newKey, err := greyproxy.LoadOrGenerateKey(greyproxyDataHome()) + if err != nil { + log.Warnf("credential substitution disabled: %v", err) + } else { + shared.EncryptionKey = encKey + credStore, err := greyproxy.NewCredentialStore(shared.DB, encKey, shared.Bus) + if err != nil { + log.Warnf("credential store init failed: %v", err) + } else { + shared.CredentialStore = credStore + if newKey { + if sessions, globals, err := credStore.PurgeUnreadableCredentials(); err == nil && (sessions > 0 || globals > 0) { + log.Infof("purged %d sessions and %d global credentials (new encryption key)", sessions, globals) + } + } + credStoreCtx, credStoreCancel := context.WithCancel(context.Background()) + p.credStoreCancel = credStoreCancel + credStore.StartCleanupLoop(credStoreCtx, 60*time.Second) + // Wire credential substitution into the MITM pipeline + gostx.SetGlobalCredentialSubstituter(func(req *http.Request) *gostx.CredentialSubstitutionInfo { + result := credStore.SubstituteRequest(req) + if result.Count == 0 { + return nil + } + var sessionID string + if len(result.SessionIDs) > 0 { + sessionID = result.SessionIDs[0] + } + return &gostx.CredentialSubstitutionInfo{ + Labels: result.Labels, + SessionID: sessionID, + } + }) + log.Infof("credential store loaded: %d mappings from %d sessions", credStore.Size(), credStore.SessionCount()) + } + } + shared.Version = version // Collect listening ports for the health endpoint @@ -375,20 +418,22 @@ func (p *program) buildGreyproxyService() error { redactedRespHeaders := redactor.Redact(info.ResponseHeaders) txn, err := greyproxy.CreateHttpTransaction(shared.DB, greyproxy.HttpTransactionCreateInput{ - ContainerName: containerName, - DestinationHost: host, - DestinationPort: port, - Method: info.Method, - URL: "https://" + info.Host + info.URI, - RequestHeaders: redactedReqHeaders, - RequestBody: reqBody, - RequestContentType: reqCT, - StatusCode: info.StatusCode, - ResponseHeaders: redactedRespHeaders, - ResponseBody: respBody, - ResponseContentType: respCT, - DurationMs: info.DurationMs, - Result: "auto", + ContainerName: containerName, + DestinationHost: host, + DestinationPort: port, + Method: info.Method, + URL: "https://" + info.Host + info.URI, + RequestHeaders: redactedReqHeaders, + RequestBody: reqBody, + RequestContentType: reqCT, + StatusCode: info.StatusCode, + ResponseHeaders: redactedRespHeaders, + ResponseBody: respBody, + ResponseContentType: respCT, + DurationMs: info.DurationMs, + Result: "auto", + SubstitutedCredentials: info.SubstitutedCredentials, + SessionID: info.SessionID, }) if err != nil { log.Warnf("failed to store HTTP transaction: %v", err) diff --git a/docs/credential-substitution-api.md b/docs/credential-substitution-api.md new file mode 100644 index 0000000..ad72d0f --- /dev/null +++ b/docs/credential-substitution-api.md @@ -0,0 +1,163 @@ +# Credential Substitution API + +This document describes the REST API that greywall (or any sandbox client) uses to register credential substitution sessions with greyproxy. + +## Overview + +When greywall launches a sandboxed process, it: + +1. Reads the process's environment variables for sensitive values (API keys, tokens, etc.) +2. Generates opaque placeholder strings for each credential +3. Passes the placeholders to the sandboxed process (via modified env vars) +4. Registers a session with greyproxy, providing the placeholder-to-real-value mappings + +GreyProxy then transparently replaces placeholders with real credentials in HTTP headers and query parameters before forwarding requests upstream. + +## Session Lifecycle + +### Create or Update Session + +``` +POST /api/sessions +Content-Type: application/json +``` + +**Request body:** + +```json +{ + "session_id": "uuid-string", + "container_name": "opencode", + "mappings": { + "greyproxy:credential:v1:SESSION_ID:HEX": "sk-real-api-key-value", + "greyproxy:credential:v1:SESSION_ID:HEX2": "another-real-key" + }, + "labels": { + "greyproxy:credential:v1:SESSION_ID:HEX": "OPENAI_API_KEY", + "greyproxy:credential:v1:SESSION_ID:HEX2": "ANTHROPIC_API_KEY" + }, + "metadata": { + "pwd": "/home/user/project", + "cmd": "opencode", + "args": "--model claude-sonnet-4-20250514", + "binary_path": "/usr/bin/opencode", + "pid": "12345" + }, + "ttl_seconds": 900 +} +``` + +**Fields:** + +| Field | Type | Required | Description | +|---|---|---|---| +| `session_id` | string | Yes | Unique session identifier (UUID recommended). Used for upserts. | +| `container_name` | string | Yes | Name of the sandboxed container/process. Used for log correlation. | +| `mappings` | map[string]string | Yes | Placeholder string to real credential value. Keys must use the `greyproxy:credential:` prefix format. | +| `labels` | map[string]string | No | Placeholder string to human-readable label (e.g. env var name). Same keys as `mappings`. | +| `metadata` | map[string]string | No | Arbitrary key-value metadata about the session. Displayed in the dashboard. | +| `ttl_seconds` | int | No | Session TTL in seconds (default: 900, max: 3600). | + +**Response (200):** + +```json +{ + "session_id": "uuid-string", + "expires_at": "2026-03-25T16:00:00Z", + "credential_count": 2 +} +``` + +### Heartbeat + +Reset the TTL for an active session. Call this periodically to keep the session alive. + +``` +POST /api/sessions/:id/heartbeat +``` + +**Response (200):** + +```json +{ + "session_id": "uuid-string", + "expires_at": "2026-03-25T16:15:00Z" +} +``` + +**Response (404):** Session not found or expired. + +### Delete Session + +Immediately expire and remove a session. + +``` +DELETE /api/sessions/:id +``` + +**Response (200):** + +```json +{ + "session_id": "uuid-string", + "deleted": true +} +``` + +### List Sessions + +Returns all active (non-expired) sessions. + +``` +GET /api/sessions +``` + +**Response (200):** Array of session objects with credential labels, counts, metadata, and timestamps. + +## Metadata Convention + +The `metadata` field is a flexible string map. Greywall can send any keys it finds useful. The following keys are recognized and displayed prominently in the dashboard: + +| Key | Description | Example | +|---|---|---| +| `pwd` | Working directory of the sandboxed process | `/home/user/project` | +| `cmd` | Command name | `opencode` | +| `args` | Command arguments | `--model claude-sonnet-4-20250514` | +| `binary_path` | Absolute path to the binary | `/usr/bin/opencode` | +| `pid` | PID of the greywall sandbox process | `12345` | +| `created_by` | What created the session | `greywall v0.2.0` | + +## Placeholder Format + +Placeholders follow this format: + +``` +greyproxy:credential:v1:: +``` + +- `v1` is the version prefix +- `` is either a session ID or `"global"` for global credentials +- `` is a random hex string for uniqueness + +The client can generate these using `GeneratePlaceholder()` or construct them manually. The only requirement is that they start with `greyproxy:credential:` so the proxy's fast-path check can skip scanning headers that don't contain any placeholders. + +## What Gets Substituted + +The proxy scans and substitutes placeholders in: + +- **HTTP request headers** (all header values) +- **URL query parameters** + +It does **not** substitute in: + +- Request bodies (the body is stored as-is with placeholders visible) +- Response data + +## Transaction Tracking + +When credentials are substituted in a request, the resulting HTTP transaction is tagged with: + +- `substituted_credentials`: JSON array of credential label names that were substituted +- `session_id`: The session that provided the credentials + +These fields are visible in the transaction detail view and can be used for filtering via `GET /api/transactions?session_id=...`. diff --git a/docs/credential-substitution.md b/docs/credential-substitution.md new file mode 100644 index 0000000..ba91bab --- /dev/null +++ b/docs/credential-substitution.md @@ -0,0 +1,154 @@ +# Credential Substitution + +Greyproxy provides transparent credential substitution for sandboxed environments. Real API keys are replaced with opaque placeholders inside the sandbox; greyproxy injects the real values into HTTP requests before forwarding them upstream. + +## Overview + +``` + Sandbox Greyproxy Upstream ++--------+ +-----------+ +---------+ +| App | -- Bearer -->| Substitute| -- Bearer -->| API | +| | | ph -> real | | | ++--------+ +-----------+ +---------+ +``` + +The sandboxed process never sees real credentials. Greyproxy holds the mapping and performs substitution at the MITM layer. + +## Two types of credentials + +### Session credentials (automatic) + +When greywall launches a sandbox, it detects credential-like environment variables, generates placeholders, and registers them with greyproxy via `POST /api/sessions`. These credentials are tied to the session lifetime. + +### Global credentials (stored in dashboard) + +Global credentials are stored persistently in the greyproxy dashboard (Settings > Credentials). They are not injected automatically; greywall must explicitly request them using the `--inject` flag. + +```bash +greywall --inject ANTHROPIC_API_KEY -- opencode +``` + +At session creation, greywall sends the requested labels via the `global_credentials` field. Greyproxy resolves each label to its stored placeholder and returns the mappings. Greywall then sets these as environment variables in the sandbox. + +## Session API + +### Create session + +``` +POST /api/sessions +``` + +```json +{ + "session_id": "gw-abc123", + "container_name": "opencode", + "mappings": { + "greyproxy:credential:v1:gw-abc123:aabb...": "sk-real-key" + }, + "labels": { + "greyproxy:credential:v1:gw-abc123:aabb...": "ANTHROPIC_API_KEY" + }, + "global_credentials": ["OPENAI_API_KEY"], + "ttl_seconds": 900 +} +``` + +- `mappings`: placeholder-to-real-value pairs (for session credentials detected by greywall) +- `labels`: placeholder-to-label pairs (for display in the dashboard) +- `global_credentials`: list of global credential labels to resolve and merge into the session +- `ttl_seconds`: session lifetime (max 3600, default 900) + +Either `mappings` or `global_credentials` (or both) must be provided. + +**Response:** + +```json +{ + "session_id": "gw-abc123", + "expires_at": "2026-03-25T23:15:00Z", + "credential_count": 2, + "global_credentials": { + "OPENAI_API_KEY": "greyproxy:credential:v1:global:ccdd..." + } +} +``` + +The `global_credentials` field in the response maps each requested label to its placeholder. Greywall uses these to set environment variables in the sandbox. + +### Heartbeat + +``` +POST /api/sessions/:id/heartbeat +``` + +Resets the session TTL. Returns 404 if the session has expired (greywall will re-register). + +### Delete session + +``` +DELETE /api/sessions/:id +``` + +Removes the session and its credentials from memory. + +### List sessions + +``` +GET /api/sessions +``` + +Returns all active sessions with credential labels and substitution counts. + +## Global credentials API + +### List + +``` +GET /api/credentials +``` + +Returns all global credentials with labels, placeholders, and value previews (never the real value). + +### Create + +``` +POST /api/credentials +``` + +```json +{ + "label": "ANTHROPIC_API_KEY", + "value": "sk-ant-real-secret" +} +``` + +The value is encrypted at rest with a per-installation key (`session.key`). + +### Delete + +``` +DELETE /api/credentials/:id +``` + +## Substitution behavior + +When a request passes through the MITM layer, greyproxy scans HTTP headers and URL query parameters for strings matching the placeholder prefix (`greyproxy:credential:v1:`). Every occurrence is replaced with the corresponding real value. + +- **All occurrences** are replaced, not just the first match +- Substitution applies to both session and global credentials in the same pass +- Substitution happens after headers are cloned for storage, so the dashboard never shows real values +- Request bodies are NOT scanned (most APIs accept credentials via headers) + +## Tracking + +Each substitution increments a counter on the session. Counts are flushed to the database every 60 seconds and broadcast via WebSocket (`session.substitution` event) so the dashboard updates in real time. + +In the Activity view, requests that had credentials substituted show a shield icon. Expanding the row shows which credential labels were involved. + +## Dashboard UI + +The Settings > Credentials tab shows: + +- **Protection status**: whether HTTP and HTTPS traffic are protected (HTTPS requires TLS interception) +- **Global credentials**: stored credentials with add/delete controls and usage instructions +- **Active sessions**: currently registered sessions with credential labels, substitution counts, creation time, and active duration diff --git a/internal/gostx/internal/util/sniffing/sniffer.go b/internal/gostx/internal/util/sniffing/sniffer.go index 2a1f9cd..cdf34db 100644 --- a/internal/gostx/internal/util/sniffing/sniffer.go +++ b/internal/gostx/internal/util/sniffing/sniffer.go @@ -105,17 +105,19 @@ func WithLog(log logger.Logger) HandleOption { // HTTPRoundTripInfo contains decrypted HTTP request/response data from a MITM round-trip. type HTTPRoundTripInfo struct { - Host string - Method string - URI string - Proto string - StatusCode int - RequestHeaders http.Header - RequestBody []byte - ResponseHeaders http.Header - ResponseBody []byte - ContainerName string - DurationMs int64 + Host string + Method string + URI string + Proto string + StatusCode int + RequestHeaders http.Header + RequestBody []byte + ResponseHeaders http.Header + ResponseBody []byte + ContainerName string + DurationMs int64 + SubstitutedCredentials []string + SessionID string } // GlobalHTTPRoundTripHook is called (if set) after each MITM-intercepted HTTP round-trip. @@ -143,6 +145,40 @@ type HTTPRequestHoldInfo struct { // Return nil to allow, ErrRequestDenied to send 403, or block until approval. var GlobalHTTPRequestHoldHook func(ctx context.Context, info HTTPRequestHoldInfo) error +// CredentialSubstitutionInfo holds the result of a credential substitution pass. +type CredentialSubstitutionInfo struct { + Labels []string + SessionID string +} + +// credentialSubstituterType is the function signature for the credential substituter hook. +type credentialSubstituterType = func(req *http.Request) *CredentialSubstitutionInfo + +// globalCredentialSubstituter is called (if set) just before forwarding a request upstream. +// It modifies the request in-place, replacing credential placeholders with real values. +// Headers should already be cloned for storage before this point. +// Returns substitution info (labels and session ID) or nil if nothing was substituted. +// Access is synchronized via atomic.Pointer to prevent races between setup and request handling. +var globalCredentialSubstituter atomic.Pointer[credentialSubstituterType] + +// SetGlobalCredentialSubstituter atomically sets the credential substitution hook. +func SetGlobalCredentialSubstituter(hook credentialSubstituterType) { + if hook == nil { + globalCredentialSubstituter.Store(nil) + } else { + globalCredentialSubstituter.Store(&hook) + } +} + +// getGlobalCredentialSubstituter atomically loads the credential substitution hook. +func getGlobalCredentialSubstituter() credentialSubstituterType { + p := globalCredentialSubstituter.Load() + if p == nil { + return nil + } + return *p +} + // globalMitmEnabled controls whether MITM TLS interception is active. Default: enabled (1). var globalMitmEnabled atomic.Int32 @@ -467,6 +503,13 @@ func (h *Sniffer) httpRoundTrip(ctx context.Context, rw, cc io.ReadWriteCloser, } } + // Credential substitution: swap placeholders with real credentials. + // Headers were already cloned for storage, so this only affects the upstream request. + var subInfo *CredentialSubstitutionInfo + if credSub := getGlobalCredentialSubstituter(); credSub != nil { + subInfo = credSub(req) + } + err = req.Write(cc) if reqBody != nil { @@ -562,6 +605,10 @@ func (h *Sniffer) httpRoundTrip(ctx context.Context, rw, cc io.ReadWriteCloser, ContainerName: containerName, DurationMs: time.Since(ro.Time).Milliseconds(), } + if subInfo != nil { + info.SubstitutedCredentials = subInfo.Labels + info.SessionID = subInfo.SessionID + } if reqBody != nil { info.RequestBody = reqBody.Content() } @@ -1176,6 +1223,13 @@ func (h *h2Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + // Credential substitution: swap placeholders with real credentials (HTTP/2 path). + // Headers were already cloned for storage, so this only affects the upstream request. + var subInfo *CredentialSubstitutionInfo + if credSub := getGlobalCredentialSubstituter(); credSub != nil { + subInfo = credSub(req) + } + resp, err := h.transport.RoundTrip(req.WithContext(r.Context())) if reqBody != nil { ro.HTTP.Request.Body = reqBody.Content() @@ -1236,6 +1290,10 @@ func (h *h2Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ContainerName: containerName, DurationMs: time.Since(ro.Time).Milliseconds(), } + if subInfo != nil { + info.SubstitutedCredentials = subInfo.Labels + info.SessionID = subInfo.SessionID + } if reqBody != nil { info.RequestBody = reqBody.Content() } diff --git a/internal/gostx/mitm_hook.go b/internal/gostx/mitm_hook.go index 13594d4..0da048c 100644 --- a/internal/gostx/mitm_hook.go +++ b/internal/gostx/mitm_hook.go @@ -10,19 +10,25 @@ import ( // MitmRoundTripInfo contains decrypted HTTP request/response data from a MITM round-trip. // This re-exports the internal sniffing type for use outside the gostx/internal package. type MitmRoundTripInfo struct { - Host string - Method string - URI string - Proto string - StatusCode int - RequestHeaders http.Header - RequestBody []byte - ResponseHeaders http.Header - ResponseBody []byte - ContainerName string - DurationMs int64 + Host string + Method string + URI string + Proto string + StatusCode int + RequestHeaders http.Header + RequestBody []byte + ResponseHeaders http.Header + ResponseBody []byte + ContainerName string + DurationMs int64 + SubstitutedCredentials []string + SessionID string } +// CredentialSubstitutionInfo holds the result of a credential substitution pass. +// Re-exports the internal sniffing type. +type CredentialSubstitutionInfo = sniffing.CredentialSubstitutionInfo + // MitmRequestHoldInfo contains request details for the hold hook to evaluate. type MitmRequestHoldInfo struct { Host string @@ -69,21 +75,32 @@ func SetGlobalMitmHook(hook func(info MitmRoundTripInfo)) { } sniffing.GlobalHTTPRoundTripHook = func(info sniffing.HTTPRoundTripInfo) { hook(MitmRoundTripInfo{ - Host: info.Host, - Method: info.Method, - URI: info.URI, - Proto: info.Proto, - StatusCode: info.StatusCode, - RequestHeaders: info.RequestHeaders, - RequestBody: info.RequestBody, - ResponseHeaders: info.ResponseHeaders, - ResponseBody: info.ResponseBody, - ContainerName: info.ContainerName, - DurationMs: info.DurationMs, + Host: info.Host, + Method: info.Method, + URI: info.URI, + Proto: info.Proto, + StatusCode: info.StatusCode, + RequestHeaders: info.RequestHeaders, + RequestBody: info.RequestBody, + ResponseHeaders: info.ResponseHeaders, + ResponseBody: info.ResponseBody, + ContainerName: info.ContainerName, + DurationMs: info.DurationMs, + SubstitutedCredentials: info.SubstitutedCredentials, + SessionID: info.SessionID, }) } } +// SetGlobalCredentialSubstituter sets a callback that modifies HTTP requests in-place +// just before they are forwarded upstream. Used to replace credential placeholders +// with real values. Headers are already cloned for storage before this point. +// The hook returns substitution info (labels and session ID). +// Access is synchronized via atomic.Pointer (safe to call while requests are in flight). +func SetGlobalCredentialSubstituter(hook func(req *http.Request) *CredentialSubstitutionInfo) { + sniffing.SetGlobalCredentialSubstituter(hook) +} + // SetGlobalMitmHoldHook sets a global callback that fires BEFORE forwarding a MITM-intercepted // HTTP request upstream. Return nil to allow, ErrRequestDenied to deny with 403. // The hook may block (e.g., waiting for user approval). diff --git a/internal/greyproxy/activity.go b/internal/greyproxy/activity.go index d681ccb..e8d50d4 100644 --- a/internal/greyproxy/activity.go +++ b/internal/greyproxy/activity.go @@ -21,11 +21,12 @@ type ActivityItem struct { RuleSummary sql.NullString MitmSkipReason sql.NullString // HTTP-specific fields - Method sql.NullString - URL sql.NullString - StatusCode sql.NullInt64 - DurationMs sql.NullInt64 - ConversationID sql.NullString + Method sql.NullString + URL sql.NullString + StatusCode sql.NullInt64 + DurationMs sql.NullInt64 + ConversationID sql.NullString + SubstitutedCredentials sql.NullString } // ActivityFilter specifies filters for the unified activity query. @@ -66,7 +67,7 @@ func QueryActivity(db *DB, f ActivityFilter) ([]ActivityItem, int, error) { l.resolved_hostname, l.rule_id, r.destination_pattern as rule_summary, l.mitm_skip_reason, NULL as method, NULL as url, NULL as status_code, NULL as duration_ms, - NULL as conversation_id + NULL as conversation_id, NULL as substituted_credentials FROM request_logs l LEFT JOIN rules r ON l.rule_id = r.id WHERE %s`, where) unionParts = append(unionParts, q) @@ -103,7 +104,8 @@ func QueryActivity(db *DB, f ActivityFilter) ([]ActivityItem, int, error) { ORDER BY cl2.timestamp DESC LIMIT 1 ))) as rule_summary, NULL as mitm_skip_reason, - t.method, t.url, t.status_code, t.duration_ms, t.conversation_id + t.method, t.url, t.status_code, t.duration_ms, t.conversation_id, + t.substituted_credentials FROM http_transactions t WHERE %s`, where) unionParts = append(unionParts, q) @@ -145,7 +147,7 @@ func QueryActivity(db *DB, f ActivityFilter) ([]ActivityItem, int, error) { &item.ResolvedHostname, &item.RuleID, &item.RuleSummary, &item.MitmSkipReason, &item.Method, &item.URL, &item.StatusCode, &item.DurationMs, - &item.ConversationID, + &item.ConversationID, &item.SubstitutedCredentials, ) if err != nil { return nil, 0, fmt.Errorf("scan activity: %w", err) diff --git a/internal/greyproxy/api/credentials.go b/internal/greyproxy/api/credentials.go new file mode 100644 index 0000000..b77010b --- /dev/null +++ b/internal/greyproxy/api/credentials.go @@ -0,0 +1,94 @@ +package api + +import ( + "net/http" + + "github.com/gin-gonic/gin" + greyproxy "github.com/greyhavenhq/greyproxy/internal/greyproxy" +) + +// CredentialsListHandler returns all global credentials (labels + previews only). +func CredentialsListHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + creds, err := greyproxy.ListGlobalCredentials(s.DB) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + result := make([]greyproxy.GlobalCredentialJSON, 0, len(creds)) + for _, cred := range creds { + result = append(result, cred.ToJSON()) + } + + c.JSON(http.StatusOK, result) + } +} + +// CredentialsCreateHandler registers a new global credential. +func CredentialsCreateHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + var input greyproxy.GlobalCredentialCreateInput + if err := c.ShouldBindJSON(&input); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if input.Label == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "label is required"}) + return + } + if input.Value == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "value is required"}) + return + } + + cred, err := greyproxy.CreateGlobalCredential(s.DB, input, s.EncryptionKey) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Update in-memory store + if s.CredentialStore != nil { + s.CredentialStore.RegisterGlobalCredential(cred.Placeholder, input.Value, input.Label) + } + + c.JSON(http.StatusOK, gin.H{ + "id": cred.ID, + "label": cred.Label, + "placeholder": cred.Placeholder, + "value_preview": cred.ValuePreview, + "created_at": cred.CreatedAt.UTC().Format("2006-01-02T15:04:05Z"), + }) + } +} + +// CredentialsDeleteHandler removes a global credential. +func CredentialsDeleteHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + id := c.Param("id") + + // Get the credential first to know its placeholder for memory cleanup + cred, err := greyproxy.GetGlobalCredential(s.DB, id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "credential not found"}) + return + } + + deleted, err := greyproxy.DeleteGlobalCredential(s.DB, id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if deleted && s.CredentialStore != nil && cred != nil { + s.CredentialStore.UnregisterGlobalCredential(cred.Placeholder) + } + + c.JSON(http.StatusOK, gin.H{ + "id": id, + "deleted": deleted, + }) + } +} diff --git a/internal/greyproxy/api/router.go b/internal/greyproxy/api/router.go index 79489f4..a3ef92d 100644 --- a/internal/greyproxy/api/router.go +++ b/internal/greyproxy/api/router.go @@ -10,17 +10,19 @@ import ( // Shared holds shared state passed to all handlers. type Shared struct { - DB *greyproxy.DB - Cache *greyproxy.DNSCache - Bus *greyproxy.EventBus - Waiters *greyproxy.WaiterTracker - ConnTracker *greyproxy.ConnTracker - Notifier *greyproxy.Notifier - Settings *greyproxy.SettingsManager - Assembler *greyproxy.ConversationAssembler - Version string - Ports map[string]int - DataHome string // Path to greyproxy data directory (contains CA cert/key) + DB *greyproxy.DB + Cache *greyproxy.DNSCache + Bus *greyproxy.EventBus + Waiters *greyproxy.WaiterTracker + ConnTracker *greyproxy.ConnTracker + Notifier *greyproxy.Notifier + Settings *greyproxy.SettingsManager + Assembler *greyproxy.ConversationAssembler + CredentialStore *greyproxy.CredentialStore + EncryptionKey []byte + Version string + Ports map[string]int + DataHome string // Path to greyproxy data directory (contains CA cert/key) } // NewRouter creates the Gin router with all routes. @@ -94,6 +96,17 @@ func NewRouter(s *Shared, pathPrefix string) (*gin.Engine, *gin.RouterGroup) { api.POST("/maintenance/rebuild-conversations", RebuildConversationsHandler(s)) api.POST("/maintenance/redact-headers", RedactHeadersHandler(s)) api.GET("/maintenance/status", MaintenanceStatusHandler(s)) + + // Credential substitution sessions + api.GET("/sessions", SessionsListHandler(s)) + api.POST("/sessions", SessionsCreateHandler(s)) + api.POST("/sessions/:id/heartbeat", SessionsHeartbeatHandler(s)) + api.DELETE("/sessions/:id", SessionsDeleteHandler(s)) + + // Global credentials + api.GET("/credentials", CredentialsListHandler(s)) + api.POST("/credentials", CredentialsCreateHandler(s)) + api.DELETE("/credentials/:id", CredentialsDeleteHandler(s)) } // WebSocket diff --git a/internal/greyproxy/api/sessions.go b/internal/greyproxy/api/sessions.go new file mode 100644 index 0000000..c6a15c7 --- /dev/null +++ b/internal/greyproxy/api/sessions.go @@ -0,0 +1,170 @@ +package api + +import ( + "fmt" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + greyproxy "github.com/greyhavenhq/greyproxy/internal/greyproxy" +) + +// SessionsListHandler returns all active sessions (without credential values). +func SessionsListHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + sessions, err := greyproxy.ListSessions(s.DB) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + result := make([]greyproxy.SessionJSON, 0, len(sessions)) + for _, sess := range sessions { + labels := greyproxy.GetSessionLabels(&sess) + result = append(result, sess.ToJSON(labels)) + } + + c.JSON(http.StatusOK, result) + } +} + +// SessionsCreateHandler creates or upserts a credential substitution session. +// +// If `global_credentials` is provided (list of labels), the handler resolves +// each label to its stored placeholder and includes it in the response. +// Greywall uses the returned placeholders to set environment variables and +// rewrite .env files in the sandbox. The placeholder-to-real-value mapping +// is merged into the session so the proxy can substitute on the wire. +func SessionsCreateHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + var input greyproxy.SessionCreateInput + if err := c.ShouldBindJSON(&input); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if input.SessionID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "session_id is required"}) + return + } + if input.ContainerName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "container_name is required"}) + return + } + + // Resolve global credentials: validate they exist and merge labels for dashboard display. + // Global credential values are NOT duplicated into session mappings; the proxy + // loads them separately from the global_credentials table at startup and when + // credentials are created/deleted. This ensures deleting a global credential + // immediately stops substitution for all sessions. + var resolvedGlobals map[string]string // label -> placeholder + if len(input.GlobalCredentials) > 0 { + found, missing, err := greyproxy.GetGlobalCredentialsByLabels(s.DB, input.GlobalCredentials) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if len(missing) > 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("unknown global credentials: %s", strings.Join(missing, ", ")), + }) + return + } + + if input.Labels == nil { + input.Labels = make(map[string]string) + } + resolvedGlobals = make(map[string]string, len(found)) + + for label, cred := range found { + // Only store the label mapping (for dashboard), not the real value + input.Labels[cred.Placeholder] = label + resolvedGlobals[label] = cred.Placeholder + } + } + + if len(input.Mappings) == 0 && len(resolvedGlobals) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "no credentials provided (mappings or global_credentials required)"}) + return + } + + // Cap TTL at 1 hour + maxTTL := 3600 + if input.TTLSeconds > maxTTL { + input.TTLSeconds = maxTTL + } + + session, err := greyproxy.CreateOrUpdateSession(s.DB, input, s.EncryptionKey) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Update in-memory store + if s.CredentialStore != nil { + s.CredentialStore.RegisterSession(session, input.Mappings) + } + + resp := gin.H{ + "session_id": session.SessionID, + "expires_at": session.ExpiresAt.UTC().Format("2006-01-02T15:04:05Z"), + "credential_count": len(input.Mappings) + len(resolvedGlobals), + } + if resolvedGlobals != nil { + resp["global_credentials"] = resolvedGlobals + } + + c.JSON(http.StatusOK, resp) + } +} + +// SessionsHeartbeatHandler resets the TTL for an active session. +func SessionsHeartbeatHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + sessionID := c.Param("id") + + session, err := greyproxy.HeartbeatSession(s.DB, sessionID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if session == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "session not found or expired"}) + return + } + + if s.Bus != nil { + s.Bus.Publish(greyproxy.Event{ + Type: greyproxy.EventSessionHeartbeat, + Data: sessionID, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "session_id": session.SessionID, + "expires_at": session.ExpiresAt.UTC().Format("2006-01-02T15:04:05Z"), + }) + } +} + +// SessionsDeleteHandler removes a session and wipes credentials from DB and memory. +func SessionsDeleteHandler(s *Shared) gin.HandlerFunc { + return func(c *gin.Context) { + sessionID := c.Param("id") + + deleted, err := greyproxy.DeleteSession(s.DB, sessionID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if deleted && s.CredentialStore != nil { + s.CredentialStore.UnregisterSession(sessionID) + } + + c.JSON(http.StatusOK, gin.H{ + "session_id": sessionID, + "deleted": deleted, + }) + } +} diff --git a/internal/greyproxy/api/sessions_test.go b/internal/greyproxy/api/sessions_test.go new file mode 100644 index 0000000..acc4bae --- /dev/null +++ b/internal/greyproxy/api/sessions_test.go @@ -0,0 +1,261 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + greyproxy "github.com/greyhavenhq/greyproxy/internal/greyproxy" + _ "modernc.org/sqlite" +) + +func testEncryptionKey() []byte { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + return key +} + +func setupTestSharedWithCreds(t *testing.T) *Shared { + t.Helper() + s := setupTestShared(t) + key := testEncryptionKey() + s.EncryptionKey = key + + store, err := greyproxy.NewCredentialStore(s.DB, key, s.Bus) + if err != nil { + t.Fatal(err) + } + s.CredentialStore = store + return s +} + +func TestSessionsCreate_WithGlobalCredentials(t *testing.T) { + gin.SetMode(gin.TestMode) + s := setupTestSharedWithCreds(t) + key := testEncryptionKey() + + // Create a global credential in the DB + cred, err := greyproxy.CreateGlobalCredential(s.DB, greyproxy.GlobalCredentialCreateInput{ + Label: "ANTHROPIC_API_KEY", + Value: "sk-ant-real-secret", + }, key) + if err != nil { + t.Fatal(err) + } + + router := gin.New() + router.POST("/api/sessions", SessionsCreateHandler(s)) + + body := map[string]any{ + "session_id": "gw-test-global", + "container_name": "sandbox-1", + "global_credentials": []string{"ANTHROPIC_API_KEY"}, + "ttl_seconds": 300, + } + b, _ := json.Marshal(body) + req := httptest.NewRequest("POST", "/api/sessions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body = %s", w.Code, w.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + + // Should have 1 credential (the global one) + if count, ok := resp["credential_count"].(float64); !ok || int(count) != 1 { + t.Errorf("credential_count = %v, want 1", resp["credential_count"]) + } + + // Should return the resolved global credentials + globals, ok := resp["global_credentials"].(map[string]any) + if !ok { + t.Fatalf("global_credentials missing or wrong type: %v", resp["global_credentials"]) + } + placeholder, ok := globals["ANTHROPIC_API_KEY"].(string) + if !ok || placeholder == "" { + t.Fatalf("ANTHROPIC_API_KEY placeholder missing: %v", globals) + } + if placeholder != cred.Placeholder { + t.Errorf("placeholder = %q, want %q", placeholder, cred.Placeholder) + } + + // Verify the session was stored WITHOUT the global credential value in mappings + // (global credentials are resolved at substitution time from the global store) + session, err := greyproxy.GetSession(s.DB, "gw-test-global") + if err != nil { + t.Fatal(err) + } + mappings, err := greyproxy.DecryptSessionMappings(session, key) + if err != nil { + t.Fatal(err) + } + if _, ok := mappings[cred.Placeholder]; ok { + t.Error("global credential value should NOT be duplicated into session mappings") + } + + // Verify labels still contain the global credential label (for dashboard display) + labels := greyproxy.GetSessionLabels(session) + if labels[cred.Placeholder] != "ANTHROPIC_API_KEY" { + t.Errorf("label = %q, want %q", labels[cred.Placeholder], "ANTHROPIC_API_KEY") + } +} + +func TestSessionsCreate_MixedCredentials(t *testing.T) { + gin.SetMode(gin.TestMode) + s := setupTestSharedWithCreds(t) + key := testEncryptionKey() + + // Create a global credential + _, err := greyproxy.CreateGlobalCredential(s.DB, greyproxy.GlobalCredentialCreateInput{ + Label: "GLOBAL_KEY", + Value: "sk-global-value", + }, key) + if err != nil { + t.Fatal(err) + } + + // Create session with both session-specific mappings and global credentials + sessionPlaceholder := "greyproxy:credential:v1:gw-mixed:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1" + body := map[string]any{ + "session_id": "gw-mixed", + "container_name": "sandbox", + "mappings": map[string]string{ + sessionPlaceholder: "sk-session-value", + }, + "labels": map[string]string{ + sessionPlaceholder: "SESSION_KEY", + }, + "global_credentials": []string{"GLOBAL_KEY"}, + "ttl_seconds": 300, + } + b, _ := json.Marshal(body) + + router := gin.New() + router.POST("/api/sessions", SessionsCreateHandler(s)) + + req := httptest.NewRequest("POST", "/api/sessions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body = %s", w.Code, w.Body.String()) + } + + var resp map[string]any + json.Unmarshal(w.Body.Bytes(), &resp) + + // Should have 2 credentials total + if count := int(resp["credential_count"].(float64)); count != 2 { + t.Errorf("credential_count = %d, want 2", count) + } +} + +func TestSessionsCreate_UnknownGlobalCredential(t *testing.T) { + gin.SetMode(gin.TestMode) + s := setupTestSharedWithCreds(t) + + body := map[string]any{ + "session_id": "gw-fail", + "container_name": "sandbox", + "global_credentials": []string{"NONEXISTENT_KEY"}, + "ttl_seconds": 300, + } + b, _ := json.Marshal(body) + + router := gin.New() + router.POST("/api/sessions", SessionsCreateHandler(s)) + + req := httptest.NewRequest("POST", "/api/sessions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body = %s", w.Code, w.Body.String()) + } + + var resp map[string]any + json.Unmarshal(w.Body.Bytes(), &resp) + errMsg, _ := resp["error"].(string) + if errMsg == "" { + t.Error("expected error message") + } +} + +func TestSessionsCreate_OnlyGlobalNoMappings(t *testing.T) { + gin.SetMode(gin.TestMode) + s := setupTestSharedWithCreds(t) + key := testEncryptionKey() + + // Create a global credential + _, err := greyproxy.CreateGlobalCredential(s.DB, greyproxy.GlobalCredentialCreateInput{ + Label: "ONLY_GLOBAL", + Value: "sk-only-global", + }, key) + if err != nil { + t.Fatal(err) + } + + // Session with only global credentials, no explicit mappings + body := map[string]any{ + "session_id": "gw-global-only", + "container_name": "sandbox", + "global_credentials": []string{"ONLY_GLOBAL"}, + "ttl_seconds": 300, + } + b, _ := json.Marshal(body) + + router := gin.New() + router.POST("/api/sessions", SessionsCreateHandler(s)) + + req := httptest.NewRequest("POST", "/api/sessions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body = %s", w.Code, w.Body.String()) + } + + var resp map[string]any + json.Unmarshal(w.Body.Bytes(), &resp) + if count := int(resp["credential_count"].(float64)); count != 1 { + t.Errorf("credential_count = %d, want 1", count) + } +} + +func TestSessionsCreate_NoCredentialsAtAll(t *testing.T) { + gin.SetMode(gin.TestMode) + s := setupTestSharedWithCreds(t) + + body := map[string]any{ + "session_id": "gw-empty", + "container_name": "sandbox", + "ttl_seconds": 300, + } + b, _ := json.Marshal(body) + + router := gin.New() + router.POST("/api/sessions", SessionsCreateHandler(s)) + + req := httptest.NewRequest("POST", "/api/sessions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body = %s", w.Code, w.Body.String()) + } +} diff --git a/internal/greyproxy/api/transactions.go b/internal/greyproxy/api/transactions.go index 6ef121a..0fcc017 100644 --- a/internal/greyproxy/api/transactions.go +++ b/internal/greyproxy/api/transactions.go @@ -19,6 +19,7 @@ func TransactionsListHandler(s *Shared) gin.HandlerFunc { Container: c.Query("container"), Destination: c.Query("destination"), Method: c.Query("method"), + SessionID: c.Query("session_id"), Limit: limit, Offset: offset, } diff --git a/internal/greyproxy/credential_crud.go b/internal/greyproxy/credential_crud.go new file mode 100644 index 0000000..dd80772 --- /dev/null +++ b/internal/greyproxy/credential_crud.go @@ -0,0 +1,425 @@ +package greyproxy + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" +) + +// --- Sessions --- + +type SessionCreateInput struct { + SessionID string `json:"session_id"` + ContainerName string `json:"container_name"` + Mappings map[string]string `json:"mappings"` + Labels map[string]string `json:"labels"` + Metadata map[string]string `json:"metadata"` + TTLSeconds int `json:"ttl_seconds"` + GlobalCredentials []string `json:"global_credentials,omitempty"` +} + +// CreateOrUpdateSession creates or upserts a credential substitution session. +// Mappings are encrypted before storage. Returns the created/updated session. +func CreateOrUpdateSession(db *DB, input SessionCreateInput, encryptionKey []byte) (*Session, error) { + db.Lock() + defer db.Unlock() + + if input.TTLSeconds <= 0 { + input.TTLSeconds = 900 + } + + mappingsJSON, err := json.Marshal(input.Mappings) + if err != nil { + return nil, fmt.Errorf("marshal mappings: %w", err) + } + + mappingsEnc, err := Encrypt(encryptionKey, mappingsJSON) + if err != nil { + return nil, fmt.Errorf("encrypt mappings: %w", err) + } + + labelsJSON, err := json.Marshal(input.Labels) + if err != nil { + return nil, fmt.Errorf("marshal labels: %w", err) + } + + metadata := input.Metadata + if metadata == nil { + metadata = make(map[string]string) + } + metadataJSON, err := json.Marshal(metadata) + if err != nil { + return nil, fmt.Errorf("marshal metadata: %w", err) + } + + _, err = db.WriteDB().Exec( + `INSERT INTO sessions (session_id, container_name, mappings_enc, labels_json, metadata_json, ttl_seconds, created_at, expires_at, last_heartbeat) + VALUES (?, ?, ?, ?, ?, ?, datetime('now'), datetime('now', '+' || ? || ' seconds'), datetime('now')) + ON CONFLICT(session_id) DO UPDATE SET + container_name = excluded.container_name, + mappings_enc = excluded.mappings_enc, + labels_json = excluded.labels_json, + metadata_json = excluded.metadata_json, + ttl_seconds = excluded.ttl_seconds, + expires_at = excluded.expires_at, + last_heartbeat = excluded.last_heartbeat`, + input.SessionID, input.ContainerName, mappingsEnc, string(labelsJSON), string(metadataJSON), + input.TTLSeconds, input.TTLSeconds, + ) + if err != nil { + return nil, fmt.Errorf("upsert session: %w", err) + } + + // Re-read from DB to get the canonical timestamps + return getSessionLocked(db, input.SessionID) +} + +// HeartbeatSession resets the TTL for an active session. +// Returns the updated session or nil if not found/expired. +func HeartbeatSession(db *DB, sessionID string) (*Session, error) { + db.Lock() + defer db.Unlock() + + // Update expires_at and last_heartbeat only if the session is still active + result, err := db.WriteDB().Exec( + `UPDATE sessions SET + expires_at = datetime('now', '+' || ttl_seconds || ' seconds'), + last_heartbeat = datetime('now') + WHERE session_id = ? AND expires_at > datetime('now')`, + sessionID, + ) + if err != nil { + return nil, fmt.Errorf("heartbeat session: %w", err) + } + n, _ := result.RowsAffected() + if n == 0 { + return nil, nil // not found or expired + } + + return getSessionLocked(db, sessionID) +} + +// DeleteSession removes a session from the database. +func DeleteSession(db *DB, sessionID string) (bool, error) { + db.Lock() + defer db.Unlock() + + result, err := db.WriteDB().Exec("DELETE FROM sessions WHERE session_id = ?", sessionID) + if err != nil { + return false, fmt.Errorf("delete session: %w", err) + } + n, _ := result.RowsAffected() + return n > 0, nil +} + +// GetSession retrieves a session by ID. +func GetSession(db *DB, sessionID string) (*Session, error) { + return scanSession(db.ReadDB().QueryRow( + `SELECT session_id, container_name, mappings_enc, labels_json, metadata_json, ttl_seconds, + created_at, expires_at, last_heartbeat, substitution_count + FROM sessions WHERE session_id = ?`, sessionID, + )) +} + +// ListSessions returns all active (non-expired) sessions. +func ListSessions(db *DB) ([]Session, error) { + rows, err := db.ReadDB().Query( + `SELECT session_id, container_name, mappings_enc, labels_json, metadata_json, ttl_seconds, + created_at, expires_at, last_heartbeat, substitution_count + FROM sessions WHERE expires_at > datetime('now') ORDER BY created_at DESC`, + ) + if err != nil { + return nil, fmt.Errorf("list sessions: %w", err) + } + defer rows.Close() + + var sessions []Session + for rows.Next() { + s, err := scanSessionRow(rows) + if err != nil { + return nil, err + } + sessions = append(sessions, *s) + } + return sessions, rows.Err() +} + +// DeleteExpiredSessions removes all expired sessions and returns their IDs. +// Uses a single snapshot timestamp to avoid a race where a heartbeat between +// the SELECT and DELETE could extend a session that was already marked expired. +func DeleteExpiredSessions(db *DB) ([]string, error) { + db.Lock() + defer db.Unlock() + + // Snapshot the current time once so both queries use the same cutoff. + var now string + if err := db.WriteDB().QueryRow("SELECT datetime('now')").Scan(&now); err != nil { + return nil, fmt.Errorf("get current time: %w", err) + } + + rows, err := db.WriteDB().Query( + "SELECT session_id FROM sessions WHERE expires_at <= ?", now, + ) + if err != nil { + return nil, fmt.Errorf("find expired sessions: %w", err) + } + defer rows.Close() + + var ids []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("scan expired session: %w", err) + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, err + } + + if len(ids) > 0 { + _, err = db.WriteDB().Exec( + "DELETE FROM sessions WHERE expires_at <= ?", now, + ) + if err != nil { + return nil, fmt.Errorf("delete expired sessions: %w", err) + } + } + + return ids, nil +} + +// IncrementSubstitutionCount atomically increments the substitution counter for a session. +func IncrementSubstitutionCount(db *DB, sessionID string, delta int64) error { + db.Lock() + defer db.Unlock() + + _, err := db.WriteDB().Exec( + "UPDATE sessions SET substitution_count = substitution_count + ? WHERE session_id = ?", + delta, sessionID, + ) + return err +} + +// LoadAllSessions returns all sessions (including expired, for startup reload). +func LoadAllSessions(db *DB) ([]Session, error) { + rows, err := db.ReadDB().Query( + `SELECT session_id, container_name, mappings_enc, labels_json, metadata_json, ttl_seconds, + created_at, expires_at, last_heartbeat, substitution_count + FROM sessions ORDER BY created_at`, + ) + if err != nil { + return nil, fmt.Errorf("load all sessions: %w", err) + } + defer rows.Close() + + var sessions []Session + for rows.Next() { + s, err := scanSessionRow(rows) + if err != nil { + return nil, err + } + sessions = append(sessions, *s) + } + return sessions, rows.Err() +} + +// getSessionLocked retrieves a session (caller must hold write lock). +func getSessionLocked(db *DB, sessionID string) (*Session, error) { + return scanSession(db.WriteDB().QueryRow( + `SELECT session_id, container_name, mappings_enc, labels_json, metadata_json, ttl_seconds, + created_at, expires_at, last_heartbeat, substitution_count + FROM sessions WHERE session_id = ?`, sessionID, + )) +} + +type scannable interface { + Scan(dest ...any) error +} + +func scanSession(row scannable) (*Session, error) { + var s Session + err := row.Scan( + &s.SessionID, &s.ContainerName, &s.MappingsEnc, &s.LabelsJSON, &s.MetadataJSON, + &s.TTLSeconds, &s.CreatedAt, &s.ExpiresAt, &s.LastHeartbeat, &s.SubstitutionCount, + ) + if err != nil { + return nil, err + } + return &s, nil +} + +func scanSessionRow(rows scannable) (*Session, error) { + return scanSession(rows) +} + +// DecryptSessionMappings decrypts the encrypted mappings blob. +func DecryptSessionMappings(s *Session, key []byte) (map[string]string, error) { + plaintext, err := Decrypt(key, s.MappingsEnc) + if err != nil { + return nil, fmt.Errorf("decrypt mappings: %w", err) + } + var mappings map[string]string + if err := json.Unmarshal(plaintext, &mappings); err != nil { + return nil, fmt.Errorf("unmarshal mappings: %w", err) + } + return mappings, nil +} + +// ParseSessionLabels parses the labels JSON string. +func ParseSessionLabels(s *Session) (map[string]string, error) { + var labels map[string]string + if err := json.Unmarshal([]byte(s.LabelsJSON), &labels); err != nil { + return nil, fmt.Errorf("unmarshal labels: %w", err) + } + return labels, nil +} + +// --- Global Credentials --- + +type GlobalCredentialCreateInput struct { + Label string `json:"label"` + Value string `json:"value"` +} + +// CreateGlobalCredential creates a new global credential with an auto-generated placeholder. +func CreateGlobalCredential(db *DB, input GlobalCredentialCreateInput, encryptionKey []byte) (*GlobalCredential, error) { + db.Lock() + defer db.Unlock() + + id, err := generateCredentialID() + if err != nil { + return nil, err + } + + placeholder, err := GeneratePlaceholder("global") + if err != nil { + return nil, err + } + + valueEnc, err := Encrypt(encryptionKey, []byte(input.Value)) + if err != nil { + return nil, fmt.Errorf("encrypt value: %w", err) + } + + preview := MaskCredentialValue(input.Value) + + _, err = db.WriteDB().Exec( + `INSERT INTO global_credentials (id, label, placeholder, value_enc, value_preview, created_at) + VALUES (?, ?, ?, ?, ?, datetime('now'))`, + id, input.Label, placeholder, valueEnc, preview, + ) + if err != nil { + return nil, fmt.Errorf("insert global credential: %w", err) + } + + // Re-read to get canonical timestamp + return GetGlobalCredentialLocked(db, id) +} + +// ListGlobalCredentials returns all global credentials (without decrypted values). +func ListGlobalCredentials(db *DB) ([]GlobalCredential, error) { + rows, err := db.ReadDB().Query( + `SELECT id, label, placeholder, value_enc, value_preview, created_at + FROM global_credentials ORDER BY created_at DESC`, + ) + if err != nil { + return nil, fmt.Errorf("list global credentials: %w", err) + } + defer rows.Close() + + var creds []GlobalCredential + for rows.Next() { + var c GlobalCredential + if err := rows.Scan(&c.ID, &c.Label, &c.Placeholder, &c.ValueEnc, &c.ValuePreview, &c.CreatedAt); err != nil { + return nil, fmt.Errorf("scan global credential: %w", err) + } + creds = append(creds, c) + } + return creds, rows.Err() +} + +// GetGlobalCredentialsByLabels retrieves global credentials matching the given labels. +// Returns a map of label -> GlobalCredential for found credentials and a list of missing labels. +func GetGlobalCredentialsByLabels(db *DB, labels []string) (map[string]*GlobalCredential, []string, error) { + if len(labels) == 0 { + return nil, nil, nil + } + + creds, err := ListGlobalCredentials(db) + if err != nil { + return nil, nil, err + } + + byLabel := make(map[string]*GlobalCredential, len(creds)) + for i := range creds { + byLabel[creds[i].Label] = &creds[i] + } + + found := make(map[string]*GlobalCredential, len(labels)) + var missing []string + for _, label := range labels { + if c, ok := byLabel[label]; ok { + found[label] = c + } else { + missing = append(missing, label) + } + } + return found, missing, nil +} + +// GetGlobalCredential retrieves a single global credential by ID. +func GetGlobalCredential(db *DB, id string) (*GlobalCredential, error) { + return scanGlobalCredential(db.ReadDB().QueryRow( + `SELECT id, label, placeholder, value_enc, value_preview, created_at + FROM global_credentials WHERE id = ?`, id, + )) +} + +// GetGlobalCredentialLocked retrieves a credential using the write DB (caller must hold lock). +func GetGlobalCredentialLocked(db *DB, id string) (*GlobalCredential, error) { + return scanGlobalCredential(db.WriteDB().QueryRow( + `SELECT id, label, placeholder, value_enc, value_preview, created_at + FROM global_credentials WHERE id = ?`, id, + )) +} + +func scanGlobalCredential(row scannable) (*GlobalCredential, error) { + var c GlobalCredential + err := row.Scan(&c.ID, &c.Label, &c.Placeholder, &c.ValueEnc, &c.ValuePreview, &c.CreatedAt) + if err != nil { + return nil, err + } + return &c, nil +} + +// DeleteGlobalCredential removes a global credential. +func DeleteGlobalCredential(db *DB, id string) (bool, error) { + db.Lock() + defer db.Unlock() + + result, err := db.WriteDB().Exec("DELETE FROM global_credentials WHERE id = ?", id) + if err != nil { + return false, fmt.Errorf("delete global credential: %w", err) + } + n, _ := result.RowsAffected() + return n > 0, nil +} + +// DecryptGlobalCredentialValue decrypts the encrypted credential value. +func DecryptGlobalCredentialValue(c *GlobalCredential, key []byte) (string, error) { + plaintext, err := Decrypt(key, c.ValueEnc) + if err != nil { + return "", fmt.Errorf("decrypt credential: %w", err) + } + return string(plaintext), nil +} + +func generateCredentialID() (string, error) { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate credential ID: %w", err) + } + return "cred_" + hex.EncodeToString(b), nil +} diff --git a/internal/greyproxy/credential_crypto.go b/internal/greyproxy/credential_crypto.go new file mode 100644 index 0000000..57567ee --- /dev/null +++ b/internal/greyproxy/credential_crypto.go @@ -0,0 +1,129 @@ +package greyproxy + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/hex" + "fmt" + "io" + "os" + "path/filepath" +) + +const ( + // sessionKeyFile is the filename for the encryption master key. + sessionKeyFile = "session.key" + // sessionKeySize is the AES-256 key size in bytes. + sessionKeySize = 32 + + // PlaceholderPrefix is the prefix used to identify credential placeholders. + PlaceholderPrefix = "greyproxy:credential:" + // PlaceholderVersion is the current placeholder format version. + PlaceholderVersion = "v1" + // placeholderRandomBytes is the number of random bytes in the hex tail. + placeholderRandomBytes = 16 +) + +// LoadOrGenerateKey loads the master encryption key from dataDir/session.key, +// or generates a new one if the file does not exist. +// Returns the 32-byte key and whether a new key was generated. +// If the file exists but has the wrong size, an error is returned rather than +// silently overwriting (which would make all stored credentials unreadable). +func LoadOrGenerateKey(dataDir string) ([]byte, bool, error) { + keyPath := filepath.Join(dataDir, sessionKeyFile) + + data, err := os.ReadFile(keyPath) + if err == nil { + if len(data) == sessionKeySize { + return data, false, nil + } + return nil, false, fmt.Errorf("encryption key file %s is corrupt (got %d bytes, want %d); "+ + "delete it manually to generate a new key (existing encrypted credentials will be lost)", + keyPath, len(data), sessionKeySize) + } + + if !os.IsNotExist(err) { + return nil, false, fmt.Errorf("read key file: %w", err) + } + + // File does not exist; generate new key + key := make([]byte, sessionKeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return nil, false, fmt.Errorf("generate key: %w", err) + } + + if err := os.MkdirAll(dataDir, 0o750); err != nil { + return nil, false, fmt.Errorf("create data dir: %w", err) + } + if err := os.WriteFile(keyPath, key, 0o600); err != nil { + return nil, false, fmt.Errorf("write key file: %w", err) + } + + return key, true, nil +} + +// Encrypt encrypts plaintext using AES-256-GCM with the given key. +// Returns nonce (12 bytes) || ciphertext || GCM tag (16 bytes). +func Encrypt(key, plaintext []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("create GCM: %w", err) + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, fmt.Errorf("generate nonce: %w", err) + } + + return gcm.Seal(nonce, nonce, plaintext, nil), nil +} + +// Decrypt decrypts data encrypted by Encrypt using AES-256-GCM. +// Expects input format: nonce (12 bytes) || ciphertext || GCM tag (16 bytes). +func Decrypt(key, ciphertext []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("create GCM: %w", err) + } + + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + return nil, fmt.Errorf("ciphertext too short") + } + + nonce, ct := ciphertext[:nonceSize], ciphertext[nonceSize:] + return gcm.Open(nil, nonce, ct, nil) +} + +// GeneratePlaceholder creates a credential placeholder string. +// Format: greyproxy:credential:v1::<32_hex_chars> +func GeneratePlaceholder(sessionID string) (string, error) { + b := make([]byte, placeholderRandomBytes) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "", fmt.Errorf("generate random: %w", err) + } + return fmt.Sprintf("%s%s:%s:%s", PlaceholderPrefix, PlaceholderVersion, sessionID, hex.EncodeToString(b)), nil +} + +// MaskCredentialValue returns a masked preview of a credential value. +// Shows first 6 + last 3 chars for values >= 9 chars, otherwise masks all but last 2. +func MaskCredentialValue(value string) string { + if len(value) >= 9 { + return value[:6] + "***" + value[len(value)-3:] + } + if len(value) <= 2 { + return "***" + } + return "***" + value[len(value)-2:] +} diff --git a/internal/greyproxy/credential_crypto_test.go b/internal/greyproxy/credential_crypto_test.go new file mode 100644 index 0000000..4cb22dc --- /dev/null +++ b/internal/greyproxy/credential_crypto_test.go @@ -0,0 +1,239 @@ +package greyproxy + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestEncryptDecryptRoundTrip(t *testing.T) { + key := make([]byte, sessionKeySize) + for i := range key { + key[i] = byte(i) + } + + plaintext := []byte("sk-ant-api03-real-secret-key-value") + + encrypted, err := Encrypt(key, plaintext) + if err != nil { + t.Fatal(err) + } + + if len(encrypted) <= len(plaintext) { + t.Error("encrypted data should be longer than plaintext (nonce + tag)") + } + + decrypted, err := Decrypt(key, encrypted) + if err != nil { + t.Fatal(err) + } + + if string(decrypted) != string(plaintext) { + t.Errorf("got %q, want %q", decrypted, plaintext) + } +} + +func TestEncryptProducesDifferentCiphertexts(t *testing.T) { + key := make([]byte, sessionKeySize) + plaintext := []byte("same-input") + + enc1, err := Encrypt(key, plaintext) + if err != nil { + t.Fatal(err) + } + enc2, err := Encrypt(key, plaintext) + if err != nil { + t.Fatal(err) + } + + if string(enc1) == string(enc2) { + t.Error("two encryptions of the same plaintext should produce different ciphertext (random nonce)") + } +} + +func TestDecryptTamperedCiphertext(t *testing.T) { + key := make([]byte, sessionKeySize) + plaintext := []byte("sensitive-data") + + encrypted, err := Encrypt(key, plaintext) + if err != nil { + t.Fatal(err) + } + + // Tamper with a byte in the ciphertext + encrypted[len(encrypted)-1] ^= 0xff + + _, err = Decrypt(key, encrypted) + if err == nil { + t.Error("expected error when decrypting tampered ciphertext") + } +} + +func TestDecryptWrongKey(t *testing.T) { + key1 := make([]byte, sessionKeySize) + key2 := make([]byte, sessionKeySize) + key2[0] = 1 + + encrypted, err := Encrypt(key1, []byte("secret")) + if err != nil { + t.Fatal(err) + } + + _, err = Decrypt(key2, encrypted) + if err == nil { + t.Error("expected error when decrypting with wrong key") + } +} + +func TestDecryptTooShortCiphertext(t *testing.T) { + key := make([]byte, sessionKeySize) + _, err := Decrypt(key, []byte("short")) + if err == nil { + t.Error("expected error for short ciphertext") + } +} + +func TestLoadOrGenerateKey_NewKey(t *testing.T) { + dir := t.TempDir() + + key, isNew, err := LoadOrGenerateKey(dir) + if err != nil { + t.Fatal(err) + } + if !isNew { + t.Error("expected isNew=true for first call") + } + if len(key) != sessionKeySize { + t.Errorf("key length = %d, want %d", len(key), sessionKeySize) + } + + // Verify file exists with correct permissions + info, err := os.Stat(filepath.Join(dir, sessionKeyFile)) + if err != nil { + t.Fatal(err) + } + if info.Mode().Perm() != 0o600 { + t.Errorf("key file permissions = %o, want 0600", info.Mode().Perm()) + } +} + +func TestLoadOrGenerateKey_ExistingKey(t *testing.T) { + dir := t.TempDir() + + key1, _, err := LoadOrGenerateKey(dir) + if err != nil { + t.Fatal(err) + } + + key2, isNew, err := LoadOrGenerateKey(dir) + if err != nil { + t.Fatal(err) + } + if isNew { + t.Error("expected isNew=false for second call") + } + if string(key1) != string(key2) { + t.Error("second load should return same key") + } +} + +func TestLoadOrGenerateKey_CorruptKeyFile(t *testing.T) { + dir := t.TempDir() + + // Write a corrupt (too short) key file + if err := os.WriteFile(filepath.Join(dir, sessionKeyFile), []byte("short"), 0o600); err != nil { + t.Fatal(err) + } + + _, _, err := LoadOrGenerateKey(dir) + if err == nil { + t.Fatal("expected error when key file is corrupt") + } + if !strings.Contains(err.Error(), "corrupt") { + t.Errorf("error should mention corruption, got: %v", err) + } +} + +func TestLoadOrGenerateKey_UnreadableKeyFile(t *testing.T) { + dir := t.TempDir() + + // Write a key file with no read permissions + keyPath := filepath.Join(dir, sessionKeyFile) + if err := os.WriteFile(keyPath, []byte("data"), 0o000); err != nil { + t.Fatal(err) + } + + _, _, err := LoadOrGenerateKey(dir) + if err == nil { + t.Fatal("expected error when key file is unreadable") + } +} + +func TestGeneratePlaceholder(t *testing.T) { + p, err := GeneratePlaceholder("gw-test123") + if err != nil { + t.Fatal(err) + } + + if len(p) == 0 { + t.Error("placeholder should not be empty") + } + + // Check prefix + expected := PlaceholderPrefix + PlaceholderVersion + ":gw-test123:" + if p[:len(expected)] != expected { + t.Errorf("placeholder prefix = %q, want %q", p[:len(expected)], expected) + } + + // Check hex tail length (32 hex chars = 16 bytes) + tail := p[len(expected):] + if len(tail) != 32 { + t.Errorf("hex tail length = %d, want 32", len(tail)) + } +} + +func TestGeneratePlaceholder_Uniqueness(t *testing.T) { + seen := make(map[string]bool) + for i := 0; i < 10000; i++ { + p, err := GeneratePlaceholder("test") + if err != nil { + t.Fatal(err) + } + if seen[p] { + t.Fatalf("duplicate placeholder at iteration %d", i) + } + seen[p] = true + } +} + +func TestGeneratePlaceholder_Global(t *testing.T) { + p, err := GeneratePlaceholder("global") + if err != nil { + t.Fatal(err) + } + expected := PlaceholderPrefix + PlaceholderVersion + ":global:" + if p[:len(expected)] != expected { + t.Errorf("global placeholder prefix = %q, want %q", p[:len(expected)], expected) + } +} + +func TestMaskCredentialValue(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"sk-ant-api03-abcdef-xyz", "sk-ant***xyz"}, + {"short", "***rt"}, + {"ab", "***"}, + {"", "***"}, + {"123456789", "123456***789"}, + {"12345678", "***78"}, + } + for _, tt := range tests { + got := MaskCredentialValue(tt.input) + if got != tt.want { + t.Errorf("MaskCredentialValue(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} diff --git a/internal/greyproxy/credential_store.go b/internal/greyproxy/credential_store.go new file mode 100644 index 0000000..a8c0d4d --- /dev/null +++ b/internal/greyproxy/credential_store.go @@ -0,0 +1,401 @@ +package greyproxy + +import ( + "context" + "encoding/json" + "log" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" +) + +// Session event types. +const ( + EventSessionCreated = "session.created" + EventSessionExpired = "session.expired" + EventSessionDeleted = "session.deleted" + EventSessionHeartbeat = "session.heartbeat" + EventSessionSubstitution = "session.substitution" +) + +// SubstitutionResult holds the outcome of a credential substitution pass. +type SubstitutionResult struct { + Count int + Labels []string + SessionIDs []string +} + +// CredentialStore provides fast in-memory credential placeholder lookup +// backed by encrypted DB persistence. +type CredentialStore struct { + mu sync.RWMutex + + // placeholder -> decrypted real credential + lookup map[string]string + + // placeholder -> session_id (for substitution counting) + sessionMap map[string]string + + // placeholder -> human-readable label (e.g. "OPENAI_API_KEY") + labelsMap map[string]string + + // pending substitution counts per session (batched to reduce DB writes) + pendingCounts map[string]*atomic.Int64 + + db *DB + encryptionKey []byte + bus *EventBus +} + +// NewCredentialStore creates a new store and loads existing sessions/credentials from the DB. +func NewCredentialStore(db *DB, encryptionKey []byte, bus *EventBus) (*CredentialStore, error) { + cs := &CredentialStore{ + lookup: make(map[string]string), + sessionMap: make(map[string]string), + labelsMap: make(map[string]string), + pendingCounts: make(map[string]*atomic.Int64), + db: db, + encryptionKey: encryptionKey, + bus: bus, + } + + if err := cs.loadFromDB(); err != nil { + return nil, err + } + + return cs, nil +} + +// loadFromDB rebuilds the in-memory lookup from all active sessions and global credentials. +func (cs *CredentialStore) loadFromDB() error { + now := time.Now().UTC() + + sessions, err := LoadAllSessions(cs.db) + if err != nil { + return err + } + + for _, s := range sessions { + if s.ExpiresAt.Before(now) { + continue + } + mappings, err := DecryptSessionMappings(&s, cs.encryptionKey) + if err != nil { + log.Printf("[credential_store] WARN: failed to decrypt session %s (stale key?), skipping", s.SessionID) + continue + } + labels := GetSessionLabels(&s) + for placeholder, real := range mappings { + cs.lookup[placeholder] = real + cs.sessionMap[placeholder] = s.SessionID + if label, ok := labels[placeholder]; ok { + cs.labelsMap[placeholder] = label + } + } + } + + creds, err := ListGlobalCredentials(cs.db) + if err != nil { + return err + } + + for _, c := range creds { + value, err := DecryptGlobalCredentialValue(&c, cs.encryptionKey) + if err != nil { + log.Printf("[credential_store] WARN: failed to decrypt global credential %s, skipping", c.ID) + continue + } + cs.lookup[c.Placeholder] = value + cs.labelsMap[c.Placeholder] = c.Label + } + + return nil +} + +// RegisterSession adds a session's credential mappings to the in-memory store. +func (cs *CredentialStore) RegisterSession(session *Session, mappings map[string]string) { + cs.mu.Lock() + defer cs.mu.Unlock() + + // Remove any old entries for this session first + cs.removeSessionLocked(session.SessionID) + + labels := GetSessionLabels(session) + for placeholder, real := range mappings { + cs.lookup[placeholder] = real + cs.sessionMap[placeholder] = session.SessionID + if label, ok := labels[placeholder]; ok { + cs.labelsMap[placeholder] = label + } + } + cs.pendingCounts[session.SessionID] = &atomic.Int64{} + + if cs.bus != nil { + cs.bus.Publish(Event{Type: EventSessionCreated, Data: session.SessionID}) + } +} + +// UnregisterSession removes all credential mappings for a session. +func (cs *CredentialStore) UnregisterSession(sessionID string) { + cs.mu.Lock() + defer cs.mu.Unlock() + + cs.removeSessionLocked(sessionID) + + if cs.bus != nil { + cs.bus.Publish(Event{Type: EventSessionDeleted, Data: sessionID}) + } +} + +// removeSessionLocked removes entries for a session (caller must hold write lock). +func (cs *CredentialStore) removeSessionLocked(sessionID string) { + for placeholder, sid := range cs.sessionMap { + if sid == sessionID { + delete(cs.lookup, placeholder) + delete(cs.sessionMap, placeholder) + delete(cs.labelsMap, placeholder) + } + } + delete(cs.pendingCounts, sessionID) +} + +// RegisterGlobalCredential adds a global credential to the in-memory store. +func (cs *CredentialStore) RegisterGlobalCredential(placeholder, value, label string) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.lookup[placeholder] = value + cs.labelsMap[placeholder] = label +} + +// UnregisterGlobalCredential removes a global credential from the in-memory store. +func (cs *CredentialStore) UnregisterGlobalCredential(placeholder string) { + cs.mu.Lock() + defer cs.mu.Unlock() + delete(cs.lookup, placeholder) + delete(cs.labelsMap, placeholder) +} + +// SubstituteRequest scans HTTP request headers and URL query parameters +// for credential placeholders and replaces them with real values. +// Returns a SubstitutionResult with the count, matched labels, and session IDs. +func (cs *CredentialStore) SubstituteRequest(req *http.Request) SubstitutionResult { + cs.mu.RLock() + defer cs.mu.RUnlock() + + if len(cs.lookup) == 0 { + return SubstitutionResult{} + } + + count := 0 + sessionsUsed := make(map[string]bool) + labelsUsed := make(map[string]bool) + + // Scan headers + for key, values := range req.Header { + for i, v := range values { + if !strings.Contains(v, PlaceholderPrefix) { + continue + } + replaced := cs.replaceInString(v, sessionsUsed, labelsUsed) + if replaced != v { + req.Header[key][i] = replaced + count++ + } + } + } + + // Scan URL query parameters + q := req.URL.Query() + qChanged := false + for key, values := range q { + for i, v := range values { + if !strings.Contains(v, PlaceholderPrefix) { + continue + } + replaced := cs.replaceInString(v, sessionsUsed, labelsUsed) + if replaced != v { + q[key][i] = replaced + qChanged = true + count++ + } + } + } + if qChanged { + req.URL.RawQuery = q.Encode() + } + + // Track substitution counts per session + if count > 0 { + cs.trackSubstitutions(sessionsUsed) + } + + // Build result + labels := make([]string, 0, len(labelsUsed)) + for l := range labelsUsed { + labels = append(labels, l) + } + sessionIDs := make([]string, 0, len(sessionsUsed)) + for sid := range sessionsUsed { + sessionIDs = append(sessionIDs, sid) + } + + return SubstitutionResult{Count: count, Labels: labels, SessionIDs: sessionIDs} +} + +// replaceInString replaces all placeholder occurrences in a string +// and records which sessions and labels were involved. +// Caller must hold at least a read lock. +func (cs *CredentialStore) replaceInString(s string, sessionsUsed, labelsUsed map[string]bool) string { + for placeholder, real := range cs.lookup { + if strings.Contains(s, placeholder) { + s = strings.ReplaceAll(s, placeholder, real) + if sid, ok := cs.sessionMap[placeholder]; ok { + sessionsUsed[sid] = true + } + if label, ok := cs.labelsMap[placeholder]; ok { + labelsUsed[label] = true + } + } + } + return s +} + +// trackSubstitutions increments pending counters for the given sessions. +// Caller must hold at least a read lock. Only reads pendingCounts map; +// counters are pre-allocated during RegisterSession. +func (cs *CredentialStore) trackSubstitutions(sessionsUsed map[string]bool) { + for sid := range sessionsUsed { + if counter, ok := cs.pendingCounts[sid]; ok { + counter.Add(1) + } + } +} + +// Size returns the number of credential mappings in the store. +func (cs *CredentialStore) Size() int { + cs.mu.RLock() + defer cs.mu.RUnlock() + return len(cs.lookup) +} + +// SessionCount returns the number of unique sessions with active credentials. +func (cs *CredentialStore) SessionCount() int { + cs.mu.RLock() + defer cs.mu.RUnlock() + seen := make(map[string]bool) + for _, sid := range cs.sessionMap { + seen[sid] = true + } + return len(seen) +} + +// StartCleanupLoop runs a periodic goroutine that: +// 1. Removes expired sessions from DB and memory +// 2. Flushes pending substitution counts to DB +// The loop runs every interval until ctx is cancelled. +func (cs *CredentialStore) StartCleanupLoop(ctx context.Context, interval time.Duration) { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + cs.cleanup() + cs.flushSubstitutionCounts() + } + } + }() +} + +// cleanup removes expired sessions. +func (cs *CredentialStore) cleanup() { + expiredIDs, err := DeleteExpiredSessions(cs.db) + if err != nil { + log.Printf("[credential_store] WARN: cleanup error: %v", err) + return + } + + if len(expiredIDs) > 0 { + cs.mu.Lock() + for _, id := range expiredIDs { + cs.removeSessionLocked(id) + } + cs.mu.Unlock() + + for _, id := range expiredIDs { + log.Printf("[credential_store] session %s expired", id) + if cs.bus != nil { + cs.bus.Publish(Event{Type: EventSessionExpired, Data: id}) + } + } + } +} + +// flushSubstitutionCounts writes pending counts to the DB. +func (cs *CredentialStore) flushSubstitutionCounts() { + cs.mu.RLock() + toFlush := make(map[string]int64) + for sid, counter := range cs.pendingCounts { + if v := counter.Swap(0); v > 0 { + toFlush[sid] = v + } + } + cs.mu.RUnlock() + + for sid, delta := range toFlush { + if err := IncrementSubstitutionCount(cs.db, sid, delta); err != nil { + log.Printf("[credential_store] WARN: failed to flush substitution count for %s: %v", sid, err) + continue + } + if cs.bus != nil { + cs.bus.Publish(Event{Type: EventSessionSubstitution, Data: sid}) + } + } +} + +// PurgeUnreadableCredentials removes sessions and global credentials that +// cannot be decrypted (e.g., after key rotation). Call on startup if a new +// key was generated. Returns the number of sessions and credentials purged. +func (cs *CredentialStore) PurgeUnreadableCredentials() (sessions int, globals int, err error) { + allSessions, err := LoadAllSessions(cs.db) + if err != nil { + return 0, 0, err + } + + for _, s := range allSessions { + if _, decErr := DecryptSessionMappings(&s, cs.encryptionKey); decErr != nil { + if _, delErr := DeleteSession(cs.db, s.SessionID); delErr == nil { + sessions++ + } + } + } + + allCreds, err := ListGlobalCredentials(cs.db) + if err != nil { + return sessions, 0, err + } + + for _, c := range allCreds { + if _, decErr := DecryptGlobalCredentialValue(&c, cs.encryptionKey); decErr != nil { + if _, delErr := DeleteGlobalCredential(cs.db, c.ID); delErr == nil { + globals++ + } + } + } + + return sessions, globals, nil +} + +// GetSessionLabels returns the labels map for a session from its JSON field. +func GetSessionLabels(s *Session) map[string]string { + var labels map[string]string + if err := json.Unmarshal([]byte(s.LabelsJSON), &labels); err != nil { + return make(map[string]string) + } + return labels +} diff --git a/internal/greyproxy/credential_store_test.go b/internal/greyproxy/credential_store_test.go new file mode 100644 index 0000000..b27ef76 --- /dev/null +++ b/internal/greyproxy/credential_store_test.go @@ -0,0 +1,978 @@ +package greyproxy + +import ( + "context" + "net/http" + "net/url" + "sync" + "testing" + "time" + + _ "modernc.org/sqlite" +) + +func testEncryptionKey() []byte { + key := make([]byte, sessionKeySize) + for i := range key { + key[i] = byte(i) + } + return key +} + +func setupCredentialStore(t *testing.T) (*CredentialStore, *DB) { + t.Helper() + db := setupTestDB(t) + bus := NewEventBus() + key := testEncryptionKey() + + cs, err := NewCredentialStore(db, key, bus) + if err != nil { + t.Fatal(err) + } + return cs, db +} + +func TestCredentialStore_SubstituteRequest_HeaderExactMatch(t *testing.T) { + cs, _ := setupCredentialStore(t) + + placeholder := "greyproxy:credential:v1:test:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1" + realKey := "sk-ant-api03-real-key" + + cs.RegisterSession(&Session{SessionID: "test"}, map[string]string{ + placeholder: realKey, + }) + + req := &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + placeholder}, + }, + URL: &url.URL{Path: "/v1/chat"}, + } + + result := cs.SubstituteRequest(req) + if result.Count != 1 { + t.Errorf("substitution count = %d, want 1", result.Count) + } + if req.Header.Get("Authorization") != "Bearer "+realKey { + t.Errorf("got header %q, want %q", req.Header.Get("Authorization"), "Bearer "+realKey) + } +} + +func TestCredentialStore_SubstituteRequest_NoMatch(t *testing.T) { + cs, _ := setupCredentialStore(t) + + cs.RegisterSession(&Session{SessionID: "test"}, map[string]string{ + "greyproxy:credential:v1:test:aaaa": "real", + }) + + req := &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer sk-regular-key"}, + }, + URL: &url.URL{Path: "/v1/chat"}, + } + + result := cs.SubstituteRequest(req) + if result.Count != 0 { + t.Errorf("substitution count = %d, want 0", result.Count) + } + if req.Header.Get("Authorization") != "Bearer sk-regular-key" { + t.Error("header should not be modified") + } +} + +func TestCredentialStore_SubstituteRequest_QueryParam(t *testing.T) { + cs, _ := setupCredentialStore(t) + + placeholder := "greyproxy:credential:v1:test:bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" + realKey := "actual-api-key" + + cs.RegisterSession(&Session{SessionID: "test"}, map[string]string{ + placeholder: realKey, + }) + + req := &http.Request{ + Header: http.Header{}, + URL: &url.URL{ + Path: "/api/data", + RawQuery: "api_key=" + placeholder + "&other=value", + }, + } + + result := cs.SubstituteRequest(req) + if result.Count != 1 { + t.Errorf("substitution count = %d, want 1", result.Count) + } + if req.URL.Query().Get("api_key") != realKey { + t.Errorf("got query param %q, want %q", req.URL.Query().Get("api_key"), realKey) + } + if req.URL.Query().Get("other") != "value" { + t.Error("other query params should be preserved") + } +} + +func TestCredentialStore_SubstituteRequest_MultipleHeaders(t *testing.T) { + cs, _ := setupCredentialStore(t) + + p1 := "greyproxy:credential:v1:s1:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1" + p2 := "greyproxy:credential:v1:s1:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa2" + + cs.RegisterSession(&Session{SessionID: "s1"}, map[string]string{ + p1: "real-key-1", + p2: "real-key-2", + }) + + req := &http.Request{ + Header: http.Header{ + "Authorization": []string{p1}, + "X-Api-Key": []string{p2}, + }, + URL: &url.URL{Path: "/"}, + } + + result := cs.SubstituteRequest(req) + if result.Count != 2 { + t.Errorf("substitution count = %d, want 2", result.Count) + } + if req.Header.Get("Authorization") != "real-key-1" { + t.Errorf("Authorization = %q, want %q", req.Header.Get("Authorization"), "real-key-1") + } + if req.Header.Get("X-Api-Key") != "real-key-2" { + t.Errorf("X-Api-Key = %q, want %q", req.Header.Get("X-Api-Key"), "real-key-2") + } +} + +func TestCredentialStore_SubstituteRequest_EmptyStore(t *testing.T) { + cs, _ := setupCredentialStore(t) + + req := &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer something"}, + }, + URL: &url.URL{Path: "/"}, + } + + result := cs.SubstituteRequest(req) + if result.Count != 0 { + t.Errorf("substitution count = %d, want 0", result.Count) + } +} + +func TestCredentialStore_RegisterUnregisterSession(t *testing.T) { + cs, _ := setupCredentialStore(t) + + p := "greyproxy:credential:v1:sess1:cccccccccccccccccccccccccccccccc" + cs.RegisterSession(&Session{SessionID: "sess1"}, map[string]string{ + p: "real", + }) + + if cs.Size() != 1 { + t.Errorf("size = %d, want 1", cs.Size()) + } + + cs.UnregisterSession("sess1") + + if cs.Size() != 0 { + t.Errorf("size = %d, want 0 after unregister", cs.Size()) + } + + // Substitution should no longer work + req := &http.Request{ + Header: http.Header{"Authorization": []string{p}}, + URL: &url.URL{Path: "/"}, + } + res := cs.SubstituteRequest(req) + if res.Count != 0 { + t.Errorf("substitution count = %d after unregister, want 0", res.Count) + } +} + +func TestCredentialStore_RegisterGlobalCredential(t *testing.T) { + cs, _ := setupCredentialStore(t) + + p := "greyproxy:credential:v1:global:dddddddddddddddddddddddddddddddd" + cs.RegisterGlobalCredential(p, "global-secret", "GLOBAL_KEY") + + req := &http.Request{ + Header: http.Header{"X-Api-Key": []string{p}}, + URL: &url.URL{Path: "/"}, + } + + result := cs.SubstituteRequest(req) + if result.Count != 1 { + t.Errorf("substitution count = %d, want 1", result.Count) + } + if req.Header.Get("X-Api-Key") != "global-secret" { + t.Errorf("got %q, want %q", req.Header.Get("X-Api-Key"), "global-secret") + } +} + +func TestCredentialStore_SessionUpsert(t *testing.T) { + cs, _ := setupCredentialStore(t) + + p1 := "greyproxy:credential:v1:s1:eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee" + p2 := "greyproxy:credential:v1:s1:ffffffffffffffffffffffffffffffff" + + cs.RegisterSession(&Session{SessionID: "s1"}, map[string]string{p1: "old-key"}) + if cs.Size() != 1 { + t.Fatalf("size = %d, want 1", cs.Size()) + } + + // Upsert with new mappings + cs.RegisterSession(&Session{SessionID: "s1"}, map[string]string{p2: "new-key"}) + if cs.Size() != 1 { + t.Errorf("size after upsert = %d, want 1 (old entry should be removed)", cs.Size()) + } + + // Old placeholder should not work + req := &http.Request{ + Header: http.Header{"Authorization": []string{p1}}, + URL: &url.URL{Path: "/"}, + } + res := cs.SubstituteRequest(req) + if res.Count != 0 { + t.Error("old placeholder should not be substituted after upsert") + } + + // New placeholder should work + req = &http.Request{ + Header: http.Header{"Authorization": []string{p2}}, + URL: &url.URL{Path: "/"}, + } + res = cs.SubstituteRequest(req) + if res.Count != 1 { + t.Error("new placeholder should be substituted after upsert") + } +} + +func TestCredentialStore_ConcurrentAccess(t *testing.T) { + cs, _ := setupCredentialStore(t) + + p := "greyproxy:credential:v1:conc:11111111111111111111111111111111" + cs.RegisterSession(&Session{SessionID: "conc"}, map[string]string{p: "real"}) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := &http.Request{ + Header: http.Header{"Authorization": []string{p}}, + URL: &url.URL{Path: "/"}, + } + cs.SubstituteRequest(req) + }() + } + + // Concurrent writes + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cs.RegisterSession(&Session{SessionID: "conc"}, map[string]string{p: "real"}) + }() + } + + wg.Wait() +} + +func TestCredentialStore_SessionCount(t *testing.T) { + cs, _ := setupCredentialStore(t) + + cs.RegisterSession(&Session{SessionID: "s1"}, map[string]string{ + "greyproxy:credential:v1:s1:aaaa": "r1", + }) + cs.RegisterSession(&Session{SessionID: "s2"}, map[string]string{ + "greyproxy:credential:v1:s2:bbbb": "r2", + }) + + if cs.SessionCount() != 2 { + t.Errorf("session count = %d, want 2", cs.SessionCount()) + } + + cs.UnregisterSession("s1") + if cs.SessionCount() != 1 { + t.Errorf("session count = %d, want 1", cs.SessionCount()) + } +} + +// --- CRUD Tests --- + +func TestSessionCreateAndGet(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + session, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "gw-test1", + ContainerName: "sandbox-1", + Mappings: map[string]string{ + "greyproxy:credential:v1:gw-test1:aaaa": "sk-real-key", + }, + Labels: map[string]string{ + "greyproxy:credential:v1:gw-test1:aaaa": "ANTHROPIC_API_KEY", + }, + TTLSeconds: 300, + }, key) + if err != nil { + t.Fatal(err) + } + if session.SessionID != "gw-test1" { + t.Errorf("session_id = %q, want %q", session.SessionID, "gw-test1") + } + if session.TTLSeconds != 300 { + t.Errorf("ttl = %d, want 300", session.TTLSeconds) + } + + // Read back + got, err := GetSession(db, "gw-test1") + if err != nil { + t.Fatal(err) + } + if got.ContainerName != "sandbox-1" { + t.Errorf("container = %q, want %q", got.ContainerName, "sandbox-1") + } + + // Decrypt and verify mappings + mappings, err := DecryptSessionMappings(got, key) + if err != nil { + t.Fatal(err) + } + if mappings["greyproxy:credential:v1:gw-test1:aaaa"] != "sk-real-key" { + t.Error("decrypted mapping does not match") + } + + // Verify labels + labels, err := ParseSessionLabels(got) + if err != nil { + t.Fatal(err) + } + if labels["greyproxy:credential:v1:gw-test1:aaaa"] != "ANTHROPIC_API_KEY" { + t.Error("label does not match") + } +} + +func TestSessionUpsert(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + input := SessionCreateInput{ + SessionID: "gw-upsert", + ContainerName: "sandbox", + Mappings: map[string]string{"p1": "v1"}, + Labels: map[string]string{"p1": "L1"}, + TTLSeconds: 300, + } + + _, err := CreateOrUpdateSession(db, input, key) + if err != nil { + t.Fatal(err) + } + + // Upsert with different mappings + input.Mappings = map[string]string{"p2": "v2"} + input.Labels = map[string]string{"p2": "L2"} + _, err = CreateOrUpdateSession(db, input, key) + if err != nil { + t.Fatal(err) + } + + got, err := GetSession(db, "gw-upsert") + if err != nil { + t.Fatal(err) + } + mappings, err := DecryptSessionMappings(got, key) + if err != nil { + t.Fatal(err) + } + if _, ok := mappings["p1"]; ok { + t.Error("old mapping should be replaced on upsert") + } + if mappings["p2"] != "v2" { + t.Error("new mapping should be present after upsert") + } +} + +func TestSessionHeartbeat(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + _, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "gw-hb", + ContainerName: "sandbox", + Mappings: map[string]string{"p": "v"}, + Labels: map[string]string{}, + TTLSeconds: 300, + }, key) + if err != nil { + t.Fatal(err) + } + + before, _ := GetSession(db, "gw-hb") + + // SQLite datetime has second-level precision, so we need to wait at least 1s + time.Sleep(1100 * time.Millisecond) + + updated, err := HeartbeatSession(db, "gw-hb") + if err != nil { + t.Fatal(err) + } + if updated == nil { + t.Fatal("heartbeat returned nil") + } + if !updated.ExpiresAt.After(before.ExpiresAt) { + t.Error("expires_at should be extended after heartbeat") + } +} + +func TestSessionHeartbeat_Expired(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + _, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "gw-expired", + ContainerName: "sandbox", + Mappings: map[string]string{"p": "v"}, + Labels: map[string]string{}, + TTLSeconds: 1, // 1 second TTL + }, key) + if err != nil { + t.Fatal(err) + } + + time.Sleep(1100 * time.Millisecond) + + updated, err := HeartbeatSession(db, "gw-expired") + if err != nil { + t.Fatal(err) + } + if updated != nil { + t.Error("heartbeat should return nil for expired session") + } +} + +func TestSessionDelete(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + _, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "gw-del", + ContainerName: "sandbox", + Mappings: map[string]string{"p": "v"}, + Labels: map[string]string{}, + TTLSeconds: 300, + }, key) + if err != nil { + t.Fatal(err) + } + + deleted, err := DeleteSession(db, "gw-del") + if err != nil { + t.Fatal(err) + } + if !deleted { + t.Error("expected deletion to succeed") + } + + got, err := GetSession(db, "gw-del") + if err == nil && got != nil { + t.Error("session should not exist after delete") + } + + // Delete non-existent + deleted, err = DeleteSession(db, "gw-del") + if err != nil { + t.Fatal(err) + } + if deleted { + t.Error("deleting non-existent should return false") + } +} + +func TestListSessions(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + for _, id := range []string{"s1", "s2", "s3"} { + _, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: id, + ContainerName: "sandbox", + Mappings: map[string]string{"p": "v"}, + Labels: map[string]string{}, + TTLSeconds: 300, + }, key) + if err != nil { + t.Fatal(err) + } + } + + sessions, err := ListSessions(db) + if err != nil { + t.Fatal(err) + } + if len(sessions) != 3 { + t.Errorf("got %d sessions, want 3", len(sessions)) + } +} + +func TestDeleteExpiredSessions(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + // Create session with 1s TTL + _, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "gw-expire", + ContainerName: "sandbox", + Mappings: map[string]string{"p": "v"}, + Labels: map[string]string{}, + TTLSeconds: 1, + }, key) + if err != nil { + t.Fatal(err) + } + + // Create session with long TTL + _, err = CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "gw-keep", + ContainerName: "sandbox", + Mappings: map[string]string{"p": "v"}, + Labels: map[string]string{}, + TTLSeconds: 300, + }, key) + if err != nil { + t.Fatal(err) + } + + time.Sleep(1100 * time.Millisecond) + + expired, err := DeleteExpiredSessions(db) + if err != nil { + t.Fatal(err) + } + if len(expired) != 1 || expired[0] != "gw-expire" { + t.Errorf("expired = %v, want [gw-expire]", expired) + } + + remaining, err := ListSessions(db) + if err != nil { + t.Fatal(err) + } + if len(remaining) != 1 { + t.Errorf("remaining = %d, want 1", len(remaining)) + } +} + +func TestSubstitutionCount(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + _, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "gw-count", + ContainerName: "sandbox", + Mappings: map[string]string{"p": "v"}, + Labels: map[string]string{}, + TTLSeconds: 300, + }, key) + if err != nil { + t.Fatal(err) + } + + if err := IncrementSubstitutionCount(db, "gw-count", 5); err != nil { + t.Fatal(err) + } + + got, err := GetSession(db, "gw-count") + if err != nil { + t.Fatal(err) + } + if got.SubstitutionCount != 5 { + t.Errorf("substitution_count = %d, want 5", got.SubstitutionCount) + } +} + +// --- Global Credential CRUD Tests --- + +func TestGlobalCredentialCreateAndList(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + cred, err := CreateGlobalCredential(db, GlobalCredentialCreateInput{ + Label: "ANTHROPIC_API_KEY", + Value: "sk-ant-api03-abcdefghijk", + }, key) + if err != nil { + t.Fatal(err) + } + if cred.Label != "ANTHROPIC_API_KEY" { + t.Errorf("label = %q, want %q", cred.Label, "ANTHROPIC_API_KEY") + } + if cred.ValuePreview != "sk-ant***ijk" { + t.Errorf("preview = %q, want %q", cred.ValuePreview, "sk-ant***ijk") + } + if cred.Placeholder == "" { + t.Error("placeholder should not be empty") + } + + // List + creds, err := ListGlobalCredentials(db) + if err != nil { + t.Fatal(err) + } + if len(creds) != 1 { + t.Errorf("got %d credentials, want 1", len(creds)) + } + + // Decrypt and verify + value, err := DecryptGlobalCredentialValue(&creds[0], key) + if err != nil { + t.Fatal(err) + } + if value != "sk-ant-api03-abcdefghijk" { + t.Errorf("decrypted value = %q", value) + } +} + +func TestGlobalCredentialDuplicateLabel(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + _, err := CreateGlobalCredential(db, GlobalCredentialCreateInput{ + Label: "MY_KEY", + Value: "val1", + }, key) + if err != nil { + t.Fatal(err) + } + + _, err = CreateGlobalCredential(db, GlobalCredentialCreateInput{ + Label: "MY_KEY", + Value: "val2", + }, key) + if err == nil { + t.Error("expected error for duplicate label") + } +} + +func TestGlobalCredentialDelete(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + cred, err := CreateGlobalCredential(db, GlobalCredentialCreateInput{ + Label: "DEL_KEY", + Value: "val", + }, key) + if err != nil { + t.Fatal(err) + } + + deleted, err := DeleteGlobalCredential(db, cred.ID) + if err != nil { + t.Fatal(err) + } + if !deleted { + t.Error("delete should succeed") + } + + creds, err := ListGlobalCredentials(db) + if err != nil { + t.Fatal(err) + } + if len(creds) != 0 { + t.Errorf("got %d credentials after delete, want 0", len(creds)) + } +} + +func TestCredentialStore_LoadFromDB(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + bus := NewEventBus() + + placeholder := "greyproxy:credential:v1:reload:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1" + + // Create session in DB + _, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "reload-test", + ContainerName: "sandbox", + Mappings: map[string]string{placeholder: "real-key"}, + Labels: map[string]string{placeholder: "MY_KEY"}, + TTLSeconds: 300, + }, key) + if err != nil { + t.Fatal(err) + } + + // Create global credential in DB + _, err = CreateGlobalCredential(db, GlobalCredentialCreateInput{ + Label: "GLOBAL_KEY", + Value: "global-real", + }, key) + if err != nil { + t.Fatal(err) + } + + // Create a new store that should load from DB + cs, err := NewCredentialStore(db, key, bus) + if err != nil { + t.Fatal(err) + } + + // Session credential should be loaded + if cs.Size() != 2 { // 1 session + 1 global + t.Errorf("size = %d, want 2", cs.Size()) + } + + // Verify session placeholder works + req := &http.Request{ + Header: http.Header{"Authorization": []string{placeholder}}, + URL: &url.URL{Path: "/"}, + } + res := cs.SubstituteRequest(req) + if res.Count != 1 { + t.Error("session placeholder should work after DB reload") + } + if req.Header.Get("Authorization") != "real-key" { + t.Errorf("got %q, want %q", req.Header.Get("Authorization"), "real-key") + } +} + +func TestCredentialStore_CleanupLoop(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + bus := NewEventBus() + + placeholder := "greyproxy:credential:v1:cleanup:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1" + + _, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "cleanup-test", + ContainerName: "sandbox", + Mappings: map[string]string{placeholder: "real"}, + Labels: map[string]string{}, + TTLSeconds: 2, // expires in 2s + }, key) + if err != nil { + t.Fatal(err) + } + + cs, err := NewCredentialStore(db, key, bus) + if err != nil { + t.Fatal(err) + } + + if cs.Size() != 1 { + t.Fatalf("initial size = %d, want 1", cs.Size()) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cs.StartCleanupLoop(ctx, 500*time.Millisecond) + + // Wait for session to expire and cleanup to run + time.Sleep(3 * time.Second) + + if cs.Size() != 0 { + t.Errorf("size after cleanup = %d, want 0", cs.Size()) + } +} + +func TestCredentialStore_PurgeUnreadableCredentials(t *testing.T) { + db := setupTestDB(t) + key1 := testEncryptionKey() + + _, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "purge-test", + ContainerName: "sandbox", + Mappings: map[string]string{"p": "v"}, + Labels: map[string]string{}, + TTLSeconds: 300, + }, key1) + if err != nil { + t.Fatal(err) + } + + // Also create a global credential with the old key + _, err = CreateGlobalCredential(db, GlobalCredentialCreateInput{ + Label: "OLD_CRED", + Value: "old-secret-value", + }, key1) + if err != nil { + t.Fatal(err) + } + + // Use a different key (simulating key rotation) + key2 := make([]byte, sessionKeySize) + key2[0] = 99 + + bus := NewEventBus() + cs, err := NewCredentialStore(db, key2, bus) + if err != nil { + t.Fatal(err) + } + + // Both should have been skipped during load + if cs.Size() != 0 { + t.Errorf("size = %d, want 0 (all encrypted with old key)", cs.Size()) + } + + sessions, globals, err := cs.PurgeUnreadableCredentials() + if err != nil { + t.Fatal(err) + } + if sessions != 1 { + t.Errorf("purged sessions = %d, want 1", sessions) + } + if globals != 1 { + t.Errorf("purged globals = %d, want 1", globals) + } + + remainingSessions, err := LoadAllSessions(db) + if err != nil { + t.Fatal(err) + } + if len(remainingSessions) != 0 { + t.Errorf("sessions in DB = %d, want 0 after purge", len(remainingSessions)) + } + + remainingCreds, err := ListGlobalCredentials(db) + if err != nil { + t.Fatal(err) + } + if len(remainingCreds) != 0 { + t.Errorf("global credentials in DB = %d, want 0 after purge", len(remainingCreds)) + } +} + +// --- GetGlobalCredentialsByLabels Tests --- + +func TestGetGlobalCredentialsByLabels(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + + // Create two global credentials + cred1, err := CreateGlobalCredential(db, GlobalCredentialCreateInput{ + Label: "ANTHROPIC_API_KEY", + Value: "sk-ant-real-key", + }, key) + if err != nil { + t.Fatal(err) + } + _, err = CreateGlobalCredential(db, GlobalCredentialCreateInput{ + Label: "OPENAI_API_KEY", + Value: "sk-oai-real-key", + }, key) + if err != nil { + t.Fatal(err) + } + + t.Run("all found", func(t *testing.T) { + found, missing, err := GetGlobalCredentialsByLabels(db, []string{"ANTHROPIC_API_KEY", "OPENAI_API_KEY"}) + if err != nil { + t.Fatal(err) + } + if len(missing) != 0 { + t.Errorf("unexpected missing labels: %v", missing) + } + if len(found) != 2 { + t.Fatalf("got %d found, want 2", len(found)) + } + if found["ANTHROPIC_API_KEY"].Placeholder != cred1.Placeholder { + t.Errorf("placeholder mismatch for ANTHROPIC_API_KEY") + } + }) + + t.Run("some missing", func(t *testing.T) { + found, missing, err := GetGlobalCredentialsByLabels(db, []string{"ANTHROPIC_API_KEY", "NONEXISTENT"}) + if err != nil { + t.Fatal(err) + } + if len(found) != 1 { + t.Errorf("got %d found, want 1", len(found)) + } + if len(missing) != 1 || missing[0] != "NONEXISTENT" { + t.Errorf("missing = %v, want [NONEXISTENT]", missing) + } + }) + + t.Run("empty labels", func(t *testing.T) { + found, missing, err := GetGlobalCredentialsByLabels(db, nil) + if err != nil { + t.Fatal(err) + } + if found != nil || missing != nil { + t.Error("expected nil for empty labels") + } + }) +} + +func TestSessionWithGlobalCredentials_Substitution(t *testing.T) { + db := setupTestDB(t) + key := testEncryptionKey() + bus := NewEventBus() + + // Create a global credential + globalCred, err := CreateGlobalCredential(db, GlobalCredentialCreateInput{ + Label: "GLOBAL_KEY", + Value: "sk-global-secret", + }, key) + if err != nil { + t.Fatal(err) + } + + // Create a session with only session-specific credentials. + // Global credentials are NOT stored in session mappings; the store + // loads them separately from the global_credentials table. + sessionPlaceholder := "greyproxy:credential:v1:gw-mixed:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1" + sessionMappings := map[string]string{ + sessionPlaceholder: "sk-session-secret", + } + + session, err := CreateOrUpdateSession(db, SessionCreateInput{ + SessionID: "gw-mixed", + ContainerName: "sandbox", + Mappings: sessionMappings, + Labels: map[string]string{ + sessionPlaceholder: "SESSION_KEY", + globalCred.Placeholder: "GLOBAL_KEY", + }, + TTLSeconds: 300, + }, key) + if err != nil { + t.Fatal(err) + } + + // Build a credential store (loads both sessions and global creds from DB) + store, err := NewCredentialStore(db, key, bus) + if err != nil { + t.Fatal(err) + } + store.RegisterSession(session, sessionMappings) + + // Test substitution of session credential + req1, _ := http.NewRequest("GET", "https://api.example.com", nil) + req1.Header.Set("Authorization", "Bearer "+sessionPlaceholder) + result1 := store.SubstituteRequest(req1) + if result1.Count != 1 { + t.Fatalf("session cred: count = %d, want 1", result1.Count) + } + if req1.Header.Get("Authorization") != "Bearer sk-session-secret" { + t.Errorf("session cred: got %q", req1.Header.Get("Authorization")) + } + + // Test substitution of global credential (loaded from global_credentials table, not session) + req2, _ := http.NewRequest("GET", "https://api.example.com", nil) + req2.Header.Set("Authorization", "Bearer "+globalCred.Placeholder) + result2 := store.SubstituteRequest(req2) + if result2.Count != 1 { + t.Fatalf("global cred: count = %d, want 1", result2.Count) + } + if req2.Header.Get("Authorization") != "Bearer sk-global-secret" { + t.Errorf("global cred: got %q", req2.Header.Get("Authorization")) + } + + // Verify labels are tracked for both + if len(result1.Labels) != 1 || result1.Labels[0] != "SESSION_KEY" { + t.Errorf("session cred labels = %v, want [SESSION_KEY]", result1.Labels) + } + if len(result2.Labels) != 1 || result2.Labels[0] != "GLOBAL_KEY" { + t.Errorf("global cred labels = %v, want [GLOBAL_KEY]", result2.Labels) + } +} diff --git a/internal/greyproxy/crud.go b/internal/greyproxy/crud.go index 4da3efa..7684926 100644 --- a/internal/greyproxy/crud.go +++ b/internal/greyproxy/crud.go @@ -1119,12 +1119,23 @@ func CreateHttpTransaction(db *DB, input HttpTransactionCreateInput) (*HttpTrans respBody = respBody[:MaxBodyCapture] } + var subCredsJSON sql.NullString + if len(input.SubstitutedCredentials) > 0 { + b, _ := json.Marshal(input.SubstitutedCredentials) + subCredsJSON = sql.NullString{String: string(b), Valid: true} + } + + var sessionID sql.NullString + if input.SessionID != "" { + sessionID = sql.NullString{String: input.SessionID, Valid: true} + } + result, err := db.WriteDB().Exec( `INSERT INTO http_transactions (container_name, destination_host, destination_port, method, url, request_headers, request_body, request_body_size, request_content_type, status_code, response_headers, response_body, response_body_size, response_content_type, - duration_ms, rule_id, result) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + duration_ms, rule_id, result, substituted_credentials, session_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, input.ContainerName, input.DestinationHost, input.DestinationPort, input.Method, input.URL, reqHeadersJSON, reqBody, reqBodySize, @@ -1135,6 +1146,7 @@ func CreateHttpTransaction(db *DB, input HttpTransactionCreateInput) (*HttpTrans sql.NullInt64{Int64: input.DurationMs, Valid: input.DurationMs > 0}, sql.NullInt64{Int64: ptrInt64OrZero(input.RuleID), Valid: input.RuleID != nil}, input.Result, + subCredsJSON, sessionID, ) if err != nil { return nil, fmt.Errorf("insert http_transaction: %w", err) @@ -1150,12 +1162,12 @@ func getHttpTransactionByID(conn *sql.DB, id int64) (*HttpTransaction, error) { `SELECT id, timestamp, container_name, destination_host, destination_port, method, url, request_headers, request_body, request_body_size, request_content_type, status_code, response_headers, response_body, response_body_size, response_content_type, - duration_ms, rule_id, result + duration_ms, rule_id, result, substituted_credentials, session_id FROM http_transactions WHERE id = ?`, id, ).Scan(&t.ID, &t.Timestamp, &t.ContainerName, &t.DestinationHost, &t.DestinationPort, &t.Method, &t.URL, &t.RequestHeaders, &t.RequestBody, &t.RequestBodySize, &t.RequestContentType, &t.StatusCode, &t.ResponseHeaders, &t.ResponseBody, &t.ResponseBodySize, &t.ResponseContentType, - &t.DurationMs, &t.RuleID, &t.Result) + &t.DurationMs, &t.RuleID, &t.Result, &t.SubstitutedCredentials, &t.SessionID) if err != nil { if err == sql.ErrNoRows { return nil, nil @@ -1173,6 +1185,7 @@ type TransactionFilter struct { Container string Destination string Method string + SessionID string FromDate *time.Time ToDate *time.Time Limit int @@ -1199,6 +1212,10 @@ func QueryHttpTransactions(db *DB, f TransactionFilter) ([]HttpTransaction, int, where = append(where, "method = ?") args = append(args, f.Method) } + if f.SessionID != "" { + where = append(where, "session_id = ?") + args = append(args, f.SessionID) + } if f.FromDate != nil { where = append(where, "timestamp >= ?") args = append(args, f.FromDate.UTC().Format("2006-01-02 15:04:05")) @@ -1221,7 +1238,7 @@ func QueryHttpTransactions(db *DB, f TransactionFilter) ([]HttpTransaction, int, `SELECT id, timestamp, container_name, destination_host, destination_port, method, url, request_headers, NULL, request_body_size, request_content_type, status_code, response_headers, NULL, response_body_size, response_content_type, - duration_ms, rule_id, result + duration_ms, rule_id, result, substituted_credentials, session_id FROM http_transactions WHERE `+whereClause+` ORDER BY timestamp DESC LIMIT ? OFFSET ?`, append(args, f.Limit, f.Offset)..., ) @@ -1236,7 +1253,7 @@ func QueryHttpTransactions(db *DB, f TransactionFilter) ([]HttpTransaction, int, if err := rows.Scan(&t.ID, &t.Timestamp, &t.ContainerName, &t.DestinationHost, &t.DestinationPort, &t.Method, &t.URL, &t.RequestHeaders, &t.RequestBody, &t.RequestBodySize, &t.RequestContentType, &t.StatusCode, &t.ResponseHeaders, &t.ResponseBody, &t.ResponseBodySize, &t.ResponseContentType, - &t.DurationMs, &t.RuleID, &t.Result); err != nil { + &t.DurationMs, &t.RuleID, &t.Result, &t.SubstitutedCredentials, &t.SessionID); err != nil { return nil, 0, err } txns = append(txns, t) diff --git a/internal/greyproxy/crud_test.go b/internal/greyproxy/crud_test.go index 3a8d03c..4f327fc 100644 --- a/internal/greyproxy/crud_test.go +++ b/internal/greyproxy/crud_test.go @@ -637,8 +637,8 @@ func TestMigrations(t *testing.T) { // Verify migration versions were recorded var count int db.ReadDB().QueryRow("SELECT COUNT(*) FROM schema_migrations").Scan(&count) - if count != 7 { - t.Errorf("expected 7 migration versions, got %d", count) + if count != 9 { + t.Errorf("expected 9 migration versions, got %d", count) } } diff --git a/internal/greyproxy/migrations.go b/internal/greyproxy/migrations.go index ce5a13b..33375bb 100644 --- a/internal/greyproxy/migrations.go +++ b/internal/greyproxy/migrations.go @@ -139,6 +139,36 @@ var migrations = []string{ // Migration 7: Add mitm_skip_reason column to request_logs for tracking why MITM was skipped `ALTER TABLE request_logs ADD COLUMN mitm_skip_reason TEXT;`, + + // Migration 8: Create sessions and global_credentials tables for credential substitution + `CREATE TABLE IF NOT EXISTS sessions ( + session_id TEXT PRIMARY KEY, + container_name TEXT NOT NULL, + mappings_enc BLOB NOT NULL, + labels_json TEXT NOT NULL DEFAULT '{}', + ttl_seconds INTEGER NOT NULL DEFAULT 900, + created_at DATETIME NOT NULL DEFAULT (datetime('now')), + expires_at DATETIME NOT NULL, + last_heartbeat DATETIME NOT NULL DEFAULT (datetime('now')), + substitution_count INTEGER NOT NULL DEFAULT 0 + ); + CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions(expires_at); + CREATE INDEX IF NOT EXISTS idx_sessions_container ON sessions(container_name); + + CREATE TABLE IF NOT EXISTS global_credentials ( + id TEXT PRIMARY KEY, + label TEXT NOT NULL UNIQUE, + placeholder TEXT NOT NULL UNIQUE, + value_enc BLOB NOT NULL, + value_preview TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT (datetime('now')) + );`, + + // Migration 9: Add credential substitution tracking, session metadata, and transaction-session linking + `ALTER TABLE http_transactions ADD COLUMN substituted_credentials TEXT DEFAULT NULL; + ALTER TABLE http_transactions ADD COLUMN session_id TEXT DEFAULT NULL; + CREATE INDEX IF NOT EXISTS idx_http_transactions_session ON http_transactions(session_id); + ALTER TABLE sessions ADD COLUMN metadata_json TEXT NOT NULL DEFAULT '{}';`, } func runMigrations(db *sql.DB) error { diff --git a/internal/greyproxy/models.go b/internal/greyproxy/models.go index 4c889b7..1005fef 100644 --- a/internal/greyproxy/models.go +++ b/internal/greyproxy/models.go @@ -185,47 +185,51 @@ func (l *RequestLog) DisplayHost() string { // HttpTransaction represents a MITM-captured HTTP request/response pair. type HttpTransaction struct { - ID int64 `json:"id"` - Timestamp time.Time `json:"timestamp"` - ContainerName string `json:"container_name"` - DestinationHost string `json:"destination_host"` - DestinationPort int `json:"destination_port"` - Method string `json:"method"` - URL string `json:"url"` - RequestHeaders sql.NullString `json:"-"` - RequestBody []byte `json:"-"` - RequestBodySize sql.NullInt64 `json:"-"` - RequestContentType sql.NullString `json:"-"` - StatusCode sql.NullInt64 `json:"status_code"` - ResponseHeaders sql.NullString `json:"-"` - ResponseBody []byte `json:"-"` - ResponseBodySize sql.NullInt64 `json:"-"` - ResponseContentType sql.NullString `json:"-"` - DurationMs sql.NullInt64 `json:"duration_ms"` - RuleID sql.NullInt64 `json:"rule_id"` - Result string `json:"result"` + ID int64 `json:"id"` + Timestamp time.Time `json:"timestamp"` + ContainerName string `json:"container_name"` + DestinationHost string `json:"destination_host"` + DestinationPort int `json:"destination_port"` + Method string `json:"method"` + URL string `json:"url"` + RequestHeaders sql.NullString `json:"-"` + RequestBody []byte `json:"-"` + RequestBodySize sql.NullInt64 `json:"-"` + RequestContentType sql.NullString `json:"-"` + StatusCode sql.NullInt64 `json:"status_code"` + ResponseHeaders sql.NullString `json:"-"` + ResponseBody []byte `json:"-"` + ResponseBodySize sql.NullInt64 `json:"-"` + ResponseContentType sql.NullString `json:"-"` + DurationMs sql.NullInt64 `json:"duration_ms"` + RuleID sql.NullInt64 `json:"rule_id"` + Result string `json:"result"` + SubstitutedCredentials sql.NullString `json:"-"` + SessionID sql.NullString `json:"-"` } type HttpTransactionJSON struct { - ID int64 `json:"id"` - Timestamp string `json:"timestamp"` - ContainerName string `json:"container_name"` - DestinationHost string `json:"destination_host"` - DestinationPort int `json:"destination_port"` - Method string `json:"method"` - URL string `json:"url"` - RequestHeaders any `json:"request_headers,omitempty"` - RequestBody *string `json:"request_body,omitempty"` - RequestBodySize *int64 `json:"request_body_size,omitempty"` - RequestContentType *string `json:"request_content_type,omitempty"` - StatusCode *int64 `json:"status_code,omitempty"` - ResponseHeaders any `json:"response_headers,omitempty"` - ResponseBody *string `json:"response_body,omitempty"` - ResponseBodySize *int64 `json:"response_body_size,omitempty"` - ResponseContentType *string `json:"response_content_type,omitempty"` - DurationMs *int64 `json:"duration_ms,omitempty"` - RuleID *int64 `json:"rule_id,omitempty"` - Result string `json:"result"` + ID int64 `json:"id"` + Timestamp string `json:"timestamp"` + ContainerName string `json:"container_name"` + DestinationHost string `json:"destination_host"` + DestinationPort int `json:"destination_port"` + Method string `json:"method"` + URL string `json:"url"` + RequestHeaders any `json:"request_headers,omitempty"` + RequestBody *string `json:"request_body,omitempty"` + RequestBodySize *int64 `json:"request_body_size,omitempty"` + RequestContentType *string `json:"request_content_type,omitempty"` + StatusCode *int64 `json:"status_code,omitempty"` + ResponseHeaders any `json:"response_headers,omitempty"` + ResponseBody *string `json:"response_body,omitempty"` + ResponseBodySize *int64 `json:"response_body_size,omitempty"` + ResponseContentType *string `json:"response_content_type,omitempty"` + DurationMs *int64 `json:"duration_ms,omitempty"` + RuleID *int64 `json:"rule_id,omitempty"` + Result string `json:"result"` + SubstitutedCredentials []string `json:"substituted_credentials,omitempty"` + SessionID *string `json:"session_id,omitempty"` } func (t *HttpTransaction) ToJSON(includeBody bool) HttpTransactionJSON { @@ -282,26 +286,37 @@ func (t *HttpTransaction) ToJSON(includeBody bool) HttpTransactionJSON { j.ResponseBody = &s } } + if t.SubstitutedCredentials.Valid && t.SubstitutedCredentials.String != "" { + var creds []string + if json.Unmarshal([]byte(t.SubstitutedCredentials.String), &creds) == nil { + j.SubstitutedCredentials = creds + } + } + if t.SessionID.Valid { + j.SessionID = &t.SessionID.String + } return j } // HttpTransactionCreateInput holds the data needed to create a transaction record. type HttpTransactionCreateInput struct { - ContainerName string - DestinationHost string - DestinationPort int - Method string - URL string - RequestHeaders http.Header - RequestBody []byte - RequestContentType string - StatusCode int - ResponseHeaders http.Header - ResponseBody []byte - ResponseContentType string - DurationMs int64 - RuleID *int64 - Result string + ContainerName string + DestinationHost string + DestinationPort int + Method string + URL string + RequestHeaders http.Header + RequestBody []byte + RequestContentType string + StatusCode int + ResponseHeaders http.Header + ResponseBody []byte + ResponseContentType string + DurationMs int64 + RuleID *int64 + Result string + SubstitutedCredentials []string + SessionID string } // DashboardStats holds aggregated data for the dashboard. @@ -343,3 +358,81 @@ type TimelinePoint struct { Allowed int `json:"allowed"` Blocked int `json:"blocked"` } + +// Session represents a credential substitution session registered by greywall. +type Session struct { + SessionID string `json:"session_id"` + ContainerName string `json:"container_name"` + MappingsEnc []byte `json:"-"` + LabelsJSON string `json:"-"` + MetadataJSON string `json:"-"` + TTLSeconds int `json:"ttl_seconds"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + LastHeartbeat time.Time `json:"last_heartbeat"` + SubstitutionCount int64 `json:"substitution_count"` +} + +type SessionJSON struct { + SessionID string `json:"session_id"` + ContainerName string `json:"container_name"` + CredentialCount int `json:"credential_count"` + CredentialLabels []string `json:"credential_labels"` + Metadata map[string]string `json:"metadata,omitempty"` + TTLSeconds int `json:"ttl_seconds"` + CreatedAt string `json:"created_at"` + ExpiresAt string `json:"expires_at"` + LastHeartbeat string `json:"last_heartbeat"` + SubstitutionCount int64 `json:"substitution_count"` +} + +func (s *Session) ToJSON(labels map[string]string) SessionJSON { + labelList := make([]string, 0, len(labels)) + for _, v := range labels { + labelList = append(labelList, v) + } + var metadata map[string]string + if s.MetadataJSON != "" && s.MetadataJSON != "{}" { + json.Unmarshal([]byte(s.MetadataJSON), &metadata) + } + return SessionJSON{ + SessionID: s.SessionID, + ContainerName: s.ContainerName, + CredentialCount: len(labels), + CredentialLabels: labelList, + Metadata: metadata, + TTLSeconds: s.TTLSeconds, + CreatedAt: s.CreatedAt.UTC().Format(time.RFC3339), + ExpiresAt: s.ExpiresAt.UTC().Format(time.RFC3339), + LastHeartbeat: s.LastHeartbeat.UTC().Format(time.RFC3339), + SubstitutionCount: s.SubstitutionCount, + } +} + +// GlobalCredential represents a persistent credential configured via the dashboard. +type GlobalCredential struct { + ID string `json:"id"` + Label string `json:"label"` + Placeholder string `json:"placeholder"` + ValueEnc []byte `json:"-"` + ValuePreview string `json:"value_preview"` + CreatedAt time.Time `json:"created_at"` +} + +type GlobalCredentialJSON struct { + ID string `json:"id"` + Label string `json:"label"` + Placeholder string `json:"placeholder"` + ValuePreview string `json:"value_preview"` + CreatedAt string `json:"created_at"` +} + +func (g *GlobalCredential) ToJSON() GlobalCredentialJSON { + return GlobalCredentialJSON{ + ID: g.ID, + Label: g.Label, + Placeholder: g.Placeholder, + ValuePreview: g.ValuePreview, + CreatedAt: g.CreatedAt.UTC().Format(time.RFC3339), + } +} diff --git a/internal/greyproxy/ui/pages.go b/internal/greyproxy/ui/pages.go index d8d6f39..8cccbb3 100644 --- a/internal/greyproxy/ui/pages.go +++ b/internal/greyproxy/ui/pages.go @@ -171,6 +171,11 @@ var funcMap = template.FuncMap{ "isExpired": func(t time.Time) bool { return time.Now().After(t) }, + "credLabels": func(raw string) []string { + var labels []string + json.Unmarshal([]byte(raw), &labels) + return labels + }, "derefStr": func(s *string) string { if s == nil { return "" @@ -856,11 +861,13 @@ func RegisterHTMXRoutes(r *gin.RouterGroup, db *greyproxy.DB, bus *greyproxy.Eve container := c.Query("container") destination := c.Query("destination") method := c.Query("method") + sessionID := c.Query("session_id") f := greyproxy.TransactionFilter{ Container: container, Destination: destination, Method: method, + SessionID: sessionID, Limit: limit, Offset: offset, } diff --git a/internal/greyproxy/ui/templates/base.html b/internal/greyproxy/ui/templates/base.html index 04cf767..e8c3289 100644 --- a/internal/greyproxy/ui/templates/base.html +++ b/internal/greyproxy/ui/templates/base.html @@ -309,6 +309,9 @@ if (msg.type === 'maintenance.progress') { window.dispatchEvent(new CustomEvent('proxy:maintenance-event', { detail: msg })); } + if (msg.type && msg.type.indexOf('session.') === 0) { + window.dispatchEvent(new CustomEvent('proxy:session-event', { detail: msg })); + } } catch (e) {} }; diff --git a/internal/greyproxy/ui/templates/partials/activity_table.html b/internal/greyproxy/ui/templates/partials/activity_table.html index be3f65f..7a02a9e 100644 --- a/internal/greyproxy/ui/templates/partials/activity_table.html +++ b/internal/greyproxy/ui/templates/partials/activity_table.html @@ -76,7 +76,7 @@ {{.DestinationHost}}{{if gt .DestinationPort 0}}:{{.DestinationPort}}{{end}} {{end}} {{else}} - {{truncate .URL.String 80}} + {{if .SubstitutedCredentials.Valid}}{{end}}{{truncate .URL.String 80}} {{end}} @@ -126,6 +126,12 @@ {{if .RuleSummary.Valid}} ({{.RuleSummary.String}}){{end}} {{end}} + {{if .SubstitutedCredentials.Valid}} +
+ Credentials substituted: + {{range credLabels .SubstitutedCredentials.String}}{{.}}{{end}} +
+ {{end}}
Loading request/response details... diff --git a/internal/greyproxy/ui/templates/partials/traffic_table.html b/internal/greyproxy/ui/templates/partials/traffic_table.html index c2337d6..9d0f3b3 100644 --- a/internal/greyproxy/ui/templates/partials/traffic_table.html +++ b/internal/greyproxy/ui/templates/partials/traffic_table.html @@ -44,7 +44,7 @@ {{.ContainerName}} - {{.URL}} + {{if .SubstitutedCredentials.Valid}}{{end}}{{.URL}} {{formatTime .Timestamp}} {{if .DurationMs.Valid}}{{.DurationMs.Int64}}ms{{end}} @@ -57,6 +57,7 @@ {{if .RequestBodySize.Valid}}
Request size: {{.RequestBodySize.Int64}} bytes
{{end}} {{if .ResponseBodySize.Valid}}
Response size: {{.ResponseBodySize.Int64}} bytes
{{end}} {{if .RequestContentType.Valid}}
Content-Type: {{.RequestContentType.String}}
{{end}} + {{if .SubstitutedCredentials.Valid}}
Credentials substituted: {{range credLabels .SubstitutedCredentials.String}}{{.}}{{end}}
{{end}}
Click to load request/response details... diff --git a/internal/greyproxy/ui/templates/settings.html b/internal/greyproxy/ui/templates/settings.html index 0c11d30..d6a7a4c 100644 --- a/internal/greyproxy/ui/templates/settings.html +++ b/internal/greyproxy/ui/templates/settings.html @@ -4,6 +4,25 @@

Settings

+ +
+ + + +
+ + +
+

Appearance

@@ -77,6 +96,70 @@

Notifications

+ +
+ + +
+ +
+ + + - -
- - +
{{end}} @@ -225,6 +320,35 @@

Maintenance

var notifEnabled = false; var notifAvailable = false; + // --- Tab switching --- + window.switchTab = function(tab) { + document.querySelectorAll('.settings-tab').forEach(function(btn) { + if (btn.getAttribute('data-tab') === tab) { + btn.classList.add('bg-primary', 'text-primary-foreground'); + btn.classList.remove('text-muted-foreground', 'hover:bg-muted'); + } else { + btn.classList.remove('bg-primary', 'text-primary-foreground'); + btn.classList.add('text-muted-foreground', 'hover:bg-muted'); + } + }); + ['general', 'mitm', 'credentials'].forEach(function(t) { + var panel = document.getElementById('tab-' + t); + if (t === tab) { + panel.classList.remove('hidden'); + } else { + panel.classList.add('hidden'); + } + }); + history.replaceState(null, '', '#' + tab); + }; + + // Restore tab from URL hash + var hash = window.location.hash.replace('#', ''); + if (['general', 'mitm', 'credentials'].indexOf(hash) !== -1) { + switchTab(hash); + } + + // --- Theme --- function updateThemeButtons(theme) { currentTheme = theme; document.querySelectorAll('.theme-btn').forEach(function(btn) { @@ -260,6 +384,7 @@

Maintenance

}).catch(function() {}); }; + // --- Notifications --- function updateNotifToggle(enabled, available) { notifEnabled = enabled; notifAvailable = available; @@ -336,6 +461,7 @@

Maintenance

.catch(function() {}); }; + // --- Advanced / Maintenance --- window.redactHeaders = function() { var btn = document.getElementById('redact-btn'); var status = document.getElementById('redact-status'); @@ -372,7 +498,6 @@

Maintenance

}); }; - // Listen for redaction progress events from the WebSocket window.addEventListener('proxy:maintenance-event', function(e) { var data = e.detail.data; if (!data || data.task !== 'redact_headers') return; @@ -475,6 +600,19 @@

Maintenance

dot.classList.remove('translate-x-5'); dot.classList.add('translate-x-1'); } + // Update credential protection status banner + updateCredStatusBanner(enabled); + } + + function updateCredStatusBanner(mitmOn) { + var httpEl = document.getElementById('cred-http-status'); + var httpsEl = document.getElementById('cred-https-status'); + httpEl.innerHTML = 'HTTP: protected'; + if (mitmOn) { + httpsEl.innerHTML = 'HTTPS: protected'; + } else { + httpsEl.innerHTML = 'HTTPS: not protected'; + } } window.toggleMitmSetting = function() { @@ -530,20 +668,20 @@

Maintenance

text.innerHTML += ' · expires ' + data.expiresAt.split('T')[0] + ''; } document.getElementById('cert-install-section').classList.remove('hidden'); - // Show install command for the current platform var cmd = data.installCommands.linux || data.installCommands.macos || ''; document.getElementById('cert-install-cmd').textContent = cmd; } - // Download link document.getElementById('cert-download-link').href = prefix + '/api/cert/download'; - } function escapeHtml(text) { - var div = document.createElement('div'); - div.appendChild(document.createTextNode(text)); - return div.innerHTML; + return String(text) + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, '''); } window.generateCert = function(force) { @@ -605,7 +743,6 @@

Maintenance

window.copyInstallCmd = function() { var cmd = document.getElementById('cert-install-cmd').textContent; navigator.clipboard.writeText(cmd).then(function() { - // Brief visual feedback var btn = document.querySelector('#cert-install-section button[onclick="copyInstallCmd()"]'); if (btn) { btn.innerHTML = ''; @@ -618,6 +755,188 @@

Maintenance

loadCertStatus(); + // --- Time helpers --- + function formatRelativeTime(date) { + var seconds = Math.floor((Date.now() - date.getTime()) / 1000); + if (seconds < 60) return 'just now'; + var minutes = Math.floor(seconds / 60); + if (minutes < 60) return minutes + 'm ago'; + var hours = Math.floor(minutes / 60); + if (hours < 24) return hours + 'h ago'; + var days = Math.floor(hours / 24); + return days + 'd ago'; + } + + function formatDuration(from, to) { + var seconds = Math.floor((to.getTime() - from.getTime()) / 1000); + if (seconds < 60) return seconds + 's'; + var minutes = Math.floor(seconds / 60); + if (minutes < 60) return minutes + 'm'; + var hours = Math.floor(minutes / 60); + var remainMin = minutes % 60; + if (hours < 24) return hours + 'h ' + remainMin + 'm'; + var days = Math.floor(hours / 24); + var remainHours = hours % 24; + return days + 'd ' + remainHours + 'h'; + } + + // --- Credential Protection --- + function loadSessions() { + fetch(prefix + '/api/sessions') + .then(function(r) { return r.json(); }) + .then(function(sessions) { + var el = document.getElementById('sessions-list'); + if (!sessions || sessions.length === 0) { + el.innerHTML = '

No active sessions. Sessions are created automatically when greywall launches a sandbox.

'; + return; + } + var html = '
'; + sessions.forEach(function(s) { + html += '
'; + html += '
'; + html += '' + escapeHtml(s.container_name) + ''; + html += ''; + html += '
'; + // Created / duration + html += '
'; + if (s.created_at) { + var created = new Date(s.created_at); + html += 'Created ' + formatRelativeTime(created) + ''; + html += '·'; + html += 'Active for ' + formatDuration(created, new Date()) + ''; + } + html += '
'; + // Credential labels + html += '
'; + (s.credential_labels || []).forEach(function(label) { + html += '' + escapeHtml(label) + ''; + }); + html += '
'; + // Stats + html += '
'; + html += '' + s.credential_count + ' credential' + (s.credential_count !== 1 ? 's' : '') + ''; + html += '' + s.substitution_count + ' substitution' + (s.substitution_count !== 1 ? 's' : '') + ''; + html += '
'; + // Metadata (if present) + if (s.metadata && Object.keys(s.metadata).length > 0) { + html += '
'; + html += '
'; + var metaKeys = Object.keys(s.metadata); + metaKeys.forEach(function(key) { + html += '' + escapeHtml(key) + ''; + html += '' + escapeHtml(s.metadata[key]) + ''; + }); + html += '
'; + html += '
'; + } + html += '
'; + }); + html += '
'; + el.innerHTML = html; + }) + .catch(function() { + document.getElementById('sessions-list').innerHTML = '

Failed to load sessions.

'; + }); + } + + function loadCredentials() { + fetch(prefix + '/api/credentials') + .then(function(r) { return r.json(); }) + .then(function(creds) { + var el = document.getElementById('credentials-list'); + if (!creds || creds.length === 0) { + el.innerHTML = '

No global credentials configured.

'; + return; + } + var html = '
'; + creds.forEach(function(c) { + html += '
'; + html += '
'; + html += '' + escapeHtml(c.label) + ''; + html += '' + escapeHtml(c.value_preview) + ''; + html += '
'; + html += ''; + html += '
'; + }); + html += '
'; + el.innerHTML = html; + }) + .catch(function() { + document.getElementById('credentials-list').innerHTML = '

Failed to load credentials.

'; + }); + } + + window.addCredential = function() { + var label = document.getElementById('cred-label').value.trim(); + var value = document.getElementById('cred-value').value; + var errorEl = document.getElementById('cred-add-error'); + errorEl.classList.add('hidden'); + + if (!label || !value) { + errorEl.textContent = 'Both label and value are required.'; + errorEl.classList.remove('hidden'); + return; + } + + var btn = document.getElementById('cred-add-btn'); + btn.disabled = true; + btn.classList.add('opacity-50'); + + fetch(prefix + '/api/credentials', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ label: label, value: value }) + }) + .then(function(r) { return r.json().then(function(d) { return { ok: r.ok, data: d }; }); }) + .then(function(res) { + btn.disabled = false; + btn.classList.remove('opacity-50'); + if (!res.ok) { + errorEl.textContent = res.data.error || 'Failed to add credential.'; + errorEl.classList.remove('hidden'); + } else { + document.getElementById('cred-label').value = ''; + document.getElementById('cred-value').value = ''; + loadCredentials(); + showToast('Credential added: ' + label); + } + }) + .catch(function() { + btn.disabled = false; + btn.classList.remove('opacity-50'); + errorEl.textContent = 'Network error.'; + errorEl.classList.remove('hidden'); + }); + }; + + function deleteCredential(id) { + fetch(prefix + '/api/credentials/' + encodeURIComponent(id), { method: 'DELETE' }) + .then(function() { loadCredentials(); }) + .catch(function() {}); + } + + function deleteSession(id) { + fetch(prefix + '/api/sessions/' + encodeURIComponent(id), { method: 'DELETE' }) + .then(function() { loadSessions(); }) + .catch(function() {}); + } + + // Event delegation for dynamically rendered buttons + document.addEventListener('click', function(e) { + var btn = e.target.closest('[data-delete-session]'); + if (btn) { deleteSession(btn.getAttribute('data-delete-session')); return; } + btn = e.target.closest('[data-delete-credential]'); + if (btn) { deleteCredential(btn.getAttribute('data-delete-credential')); return; } + }); + + loadSessions(); + loadCredentials(); + + // Listen for session events to refresh + window.addEventListener('proxy:session-event', function() { + loadSessions(); + }); + // Load settings from server fetch(prefix + '/api/settings') .then(function(r) { return r.json(); }) diff --git a/internal/greyproxy/ui/templates/traffic.html b/internal/greyproxy/ui/templates/traffic.html index cabc58e..d9542e4 100644 --- a/internal/greyproxy/ui/templates/traffic.html +++ b/internal/greyproxy/ui/templates/traffic.html @@ -111,6 +111,15 @@

HTTP Traffic

.then(function(txn) { var bodyEl = document.getElementById('txn-body-' + id); var html = ''; + if (txn.substituted_credentials && txn.substituted_credentials.length > 0) { + html += '
'; + html += ''; + html += 'Credentials substituted:'; + txn.substituted_credentials.forEach(function(label) { + html += '' + escapeHtml(label) + ''; + }); + html += '
'; + } if (txn.request_headers) { html += '
Request Headers
' + escapeHtml(JSON.stringify(txn.request_headers, null, 2)) + '
'; }