diff --git a/go.mod b/go.mod index 8996fabb..534c4920 100644 --- a/go.mod +++ b/go.mod @@ -58,6 +58,8 @@ require ( github.com/Masterminds/semver/v3 v3.3.0 // indirect github.com/alessio/shellescape v1.4.1 // indirect github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef // indirect + github.com/awnumar/memcall v0.4.0 // indirect + github.com/awnumar/memguard v0.23.0 // indirect github.com/aws/aws-sdk-go-v2 v1.27.2 // indirect github.com/aws/aws-sdk-go-v2/config v1.27.18 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.18 // indirect diff --git a/go.sum b/go.sum index e9f236ba..e929ea5f 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,10 @@ github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmV github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g= +github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w= +github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A= +github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M= github.com/aws/aws-sdk-go-v2 v1.27.2 h1:pLsTXqX93rimAOZG2FIYraDQstZaaGVVN4tNw65v0h8= github.com/aws/aws-sdk-go-v2 v1.27.2/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= github.com/aws/aws-sdk-go-v2/config v1.27.18 h1:wFvAnwOKKe7QAyIxziwSKjmer9JBMH1vzIL6W+fYuKk= diff --git a/packages/cmd/agent.go b/packages/cmd/agent.go index 1f094c86..fb1b0647 100644 --- a/packages/cmd/agent.go +++ b/packages/cmd/agent.go @@ -28,6 +28,7 @@ import ( "text/template" "time" + "github.com/awnumar/memguard" "github.com/dgraph-io/badger/v3" 
"github.com/go-resty/resty/v2" infisicalSdk "github.com/infisical/go-sdk" @@ -90,13 +91,13 @@ type RetryConfig struct { } type Config struct { - Version string `yaml:"version,omitempty"` - Infisical InfisicalConfig `yaml:"infisical"` - Auth AuthConfig `yaml:"auth"` - Sinks []Sink `yaml:"sinks"` - Cache CacheConfig `yaml:"cache,omitempty"` - Templates []Template `yaml:"templates"` - Certificates []AgentCertificateConfig `yaml:"certificates,omitempty"` + Version string `yaml:"version,omitempty"` + Infisical InfisicalConfig `yaml:"infisical"` + Auth AuthConfig `yaml:"auth"` + Sinks []Sink `yaml:"sinks"` + Cache CacheConfig `yaml:"cache,omitempty"` + Templates []Template `yaml:"templates"` + Certificates []AgentCertificateConfig `yaml:"certificates,omitempty"` } type TemplateWithID struct { @@ -195,10 +196,10 @@ type Template struct { } type CertificateLifecycleConfig struct { - RenewBeforeExpiry string `yaml:"renew-before-expiry"` - StatusCheckInterval string `yaml:"status-check-interval"` + RenewBeforeExpiry string `yaml:"renew-before-expiry"` + StatusCheckInterval string `yaml:"status-check-interval"` FailureRetryInterval string `yaml:"failure-retry-interval,omitempty"` - MaxFailureRetries int `yaml:"max-failure-retries,omitempty"` + MaxFailureRetries int `yaml:"max-failure-retries,omitempty"` } type CertificateAttributes struct { @@ -343,7 +344,10 @@ func NewCacheManager(ctx context.Context, cacheConfig *CacheConfig) (*CacheManag return &CacheManager{}, fmt.Errorf("unable to read service account token: %v. 
Please ensure the file exists and is not empty", err) } - encryptionKey := sha256.Sum256(serviceAccountToken) + hash := sha256.Sum256(serviceAccountToken) + encryptionKey := memguard.NewBufferFromBytes(hash[:]) // the hash (source) is wiped after copied to the secure buffer + + defer encryptionKey.Destroy() cacheStorage, err := cache.NewEncryptedStorage(cache.EncryptedStorageOptions{ DBPath: cacheConfig.Persistent.Path, @@ -2000,7 +2004,6 @@ func validateCertificateLifecycleConfig(certificates *[]AgentCertificateConfig) return nil } - func resolveCertificateNameReferences(certificates *[]AgentCertificateConfig, httpClient *resty.Client) error { for i := range *certificates { cert := &(*certificates)[i] @@ -2086,7 +2089,6 @@ func buildCertificateAttributes(certificate *AgentCertificateConfig) *api.Certif removeRoots = false } - attributes.RemoveRootsFromChain = removeRoots hasAny = true @@ -3207,7 +3209,6 @@ var agentCmd = &cobra.Command{ log.Warn().Msg("credential revocation timed out after 5 minutes, forcing exit") exitCode = 1 } - } os.Exit(exitCode) diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go new file mode 100644 index 00000000..156b127f --- /dev/null +++ b/packages/cmd/proxy.go @@ -0,0 +1,585 @@ +package cmd + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/Infisical/infisical-merge/packages/proxy" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/Infisical/infisical-merge/packages/util/cache" + "github.com/awnumar/memguard" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +type CacheEvictionStrategy string + +const ( + CacheEvictionStrategyOptimistic CacheEvictionStrategy = "optimistic" +) + +var proxyCmd = &cobra.Command{ + Example: `infisical proxy start`, + Short: "Used to run Infisical proxy server", + Use: "proxy", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, +} + 
+var proxyStartCmd = &cobra.Command{ + Example: `infisical proxy start --domain=https://app.infisical.com --listen-address=localhost:8081`, + Short: "Start the Infisical proxy server", + Use: "start", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: startProxyServer, +} + +var proxyDebugCmd = &cobra.Command{ + Example: `infisical proxy debug --listen-address=localhost:8081`, + Short: "Print cache debug information (dev mode only)", + Use: "debug", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: printCacheDebug, + Hidden: true, +} + +func startProxyServer(cmd *cobra.Command, args []string) { + domain, err := cmd.Flags().GetString("domain") + if err != nil { + util.HandleError(err, "Unable to parse domain flag") + } + + if domain == "" { + util.PrintErrorMessageAndExit("Domain flag is required") + } + + listenAddress, err := cmd.Flags().GetString("listen-address") + if err != nil { + util.HandleError(err, "Unable to parse listen-address flag") + } + + tlsEnabled, err := cmd.Flags().GetBool("tls-enabled") + if err != nil { + util.HandleError(err, "Unable to parse tls-enabled flag") + } + + tlsCertFile, err := cmd.Flags().GetString("tls-cert-file") + if err != nil { + util.HandleError(err, "Unable to parse tls-cert-file flag") + } + + tlsKeyFile, err := cmd.Flags().GetString("tls-key-file") + if err != nil { + util.HandleError(err, "Unable to parse tls-key-file flag") + } + + if tlsEnabled && (tlsCertFile == "" || tlsKeyFile == "") { + util.PrintErrorMessageAndExit("`tls-cert-file` and `tls-key-file` are required when `tls-enabled` is set to true") + } + + if listenAddress == "" { + util.PrintErrorMessageAndExit("Listen-address flag is required") + } + + evictionStrategy, err := cmd.Flags().GetString("eviction-strategy") + if err != nil { + util.HandleError(err, "Unable to parse eviction-strategy flag") + } + + if evictionStrategy != string(CacheEvictionStrategyOptimistic) { + util.PrintErrorMessageAndExit(fmt.Sprintf("Invalid eviction-strategy 
'%s'. Currently only 'optimistic' is supported.", evictionStrategy)) + } + + accessTokenCheckIntervalStr, err := cmd.Flags().GetString("access-token-check-interval") + if err != nil { + util.HandleError(err, "Unable to parse access-token-check-interval flag") + } + + accessTokenCheckInterval, err := util.ParseTimeDurationString(accessTokenCheckIntervalStr, true) + if err != nil { + util.PrintErrorMessageAndExit(fmt.Sprintf("Invalid access-token-check-interval format '%s'. Use formats like 5m, 1h, 1d", accessTokenCheckIntervalStr)) + } + + staticSecretsRefreshIntervalStr, err := cmd.Flags().GetString("static-secrets-refresh-interval") + if err != nil { + util.HandleError(err, "Unable to parse static-secrets-refresh-interval flag") + } + + staticSecretsRefreshInterval, err := util.ParseTimeDurationString(staticSecretsRefreshIntervalStr, true) + if err != nil { + util.PrintErrorMessageAndExit(fmt.Sprintf("Invalid static-secrets-refresh-interval format '%s'. Use formats like 30m, 1h, 1d", staticSecretsRefreshIntervalStr)) + } + + domainURL, err := url.Parse(domain) + if err != nil { + util.HandleError(err, fmt.Sprintf("Invalid domain URL: %s", domain)) + } + + httpClient := &http.Client{ + Timeout: 30 * time.Second, + } + + // Create a separate client for streaming endpoints (no timeout for long-lived connections) + streamingClient := &http.Client{ + Timeout: 0, + } + + // Create in-memory cache (no persistence, no encryption needed for ephemeral data) + // For persistent cache with encryption, use proxy.NewCacheWithOptions + encryptionKey := memguard.NewBufferRandom(32) + defer encryptionKey.Destroy() + + cache, err := proxy.NewCache(cache.EncryptedStorageOptions{ + InMemory: true, + EncryptionKey: encryptionKey, + }) + + if err != nil { + util.PrintErrorMessageAndExit(fmt.Sprintf("Failed to create cache: %v", err)) + } + + defer cache.Close() + + mux := http.NewServeMux() + + // Debug endpoint (dev mode only) + if util.IsDevelopmentMode() { + 
mux.HandleFunc("/_debug/cache", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + debugInfo := cache.GetDebugInfo() + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(debugInfo); err != nil { + log.Error().Err(err).Msg("Failed to encode cache debug info") + http.Error(w, "Failed to encode debug info", http.StatusInternalServerError) + return + } + }) + log.Info().Msg("Dev mode enabled: debug endpoint available at /_debug/cache") + } + + proxyHandler := func(w http.ResponseWriter, r *http.Request) { + // Skip debug endpoints - they're handled by mux + if strings.HasPrefix(r.URL.Path, "/_debug/") { + http.NotFound(w, r) + return + } + + token := proxy.ExtractTokenFromRequest(r) + + isCacheable := proxy.IsCacheableRequest(r.URL.Path, r.Method) + isStreaming := isStreamingEndpoint(r.URL.Path) + + // -- Cache Check -- + + if isCacheable && token != "" { + cacheKey := proxy.GenerateCacheKey(r.Method, r.URL.Path, r.URL.RawQuery, token) + + if cachedResp, found := cache.Get(cacheKey); found { + log.Info(). + Str("hash", cacheKey). + Msg("Cache hit") + + proxy.CopyHeaders(w.Header(), cachedResp.Header) + w.WriteHeader(cachedResp.StatusCode) + _, err := io.Copy(w, cachedResp.Body) + if err != nil { + log.Error().Err(err).Msg("Failed to copy cached response body") + return + } + return + } + + log.Info(). + Str("hash", cacheKey). 
+ Msg("Cache miss") + } + + // -- Proxy Request -- + + // Read request body for mutation eviction (PATCH/DELETE) or restore for forwarding + var requestBodyBytes []byte + if r.Body != nil { + requestBodyBytes, err = io.ReadAll(r.Body) + if err != nil { + log.Error().Err(err).Msg("Failed to read request body") + http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusInternalServerError) + return + } + } + + targetURL := *domainURL + targetURL.Path = domainURL.Path + r.URL.Path + targetURL.RawQuery = r.URL.RawQuery + + var bodyReader io.Reader + if requestBodyBytes != nil { + bodyReader = bytes.NewReader(requestBodyBytes) + } + + proxyReq, err := http.NewRequest(r.Method, targetURL.String(), bodyReader) + if err != nil { + log.Error().Err(err).Msg("Failed to create proxy request") + http.Error(w, fmt.Sprintf("Failed to create proxy request: %v", err), http.StatusInternalServerError) + return + } + + proxy.CopyHeaders(proxyReq.Header, r.Header) + + log.Debug(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("target", targetURL.String()). 
+ Msg("Forwarding request") + + // Use streaming client for SSE/streaming endpoints, regular client for others + clientToUse := httpClient + if isStreaming { + clientToUse = streamingClient + } + + resp, err := clientToUse.Do(proxyReq) + if err != nil { + log.Error().Err(err).Msg("Failed to forward request") + http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // -- Proxy Response -- + + proxy.CopyHeaders(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + + // For streaming endpoints, stream directly instead of buffering + if isStreaming { + // Flush headers immediately for SSE + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + + // Stream with periodic flushing for SSE events + buf := make([]byte, 1024) + for { + n, err := resp.Body.Read(buf) + if n > 0 { + if _, writeErr := w.Write(buf[:n]); writeErr != nil { + log.Error().Err(writeErr).Msg("Failed to write streaming response") + return + } + // Flush after each write to send SSE events immediately + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + break + } + if err != nil { + log.Error().Err(err).Msg("Failed to read streaming response") + return + } + } + return + } + + // For non-streaming endpoints, read into memory for caching and serving + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + log.Error().Err(err).Msg("Failed to read response body") + http.Error(w, fmt.Sprintf("Failed to read response body: %v", err), http.StatusInternalServerError) + return + } + + _, err = w.Write(bodyBytes) + if err != nil { + log.Error().Err(err).Msg("Failed to write response body") + return + } + + // -- Secret Mutation Purging -- + + if (r.Method == http.MethodPatch || r.Method == http.MethodDelete || r.Method == http.MethodPost) && + proxy.IsSecretsEndpoint(r.URL.Path) && + 
resp.StatusCode >= 200 && resp.StatusCode < 300 { + var projectId, environment, secretPath string + + if len(requestBodyBytes) > 0 { + var bodyData map[string]interface{} + if err := json.Unmarshal(requestBodyBytes, &bodyData); err == nil { + // Support both v3 (workspaceId/workspaceSlug) and v4 (projectId) + if projId, ok := bodyData["projectId"].(string); ok { + projectId = projId + } else if workspaceId, ok := bodyData["workspaceId"].(string); ok { + projectId = workspaceId + } else if workspaceSlug, ok := bodyData["workspaceSlug"].(string); ok { + projectId = workspaceSlug + } + if env, ok := bodyData["environment"].(string); ok { + environment = env + } + if path, ok := bodyData["secretPath"].(string); ok { + secretPath = path + } + } else { + log.Error(). + Err(err). + Str("method", r.Method). + Str("path", r.URL.Path). + Msg("Failed to parse mutation request body for cache purging - cache may serve stale data") + } + } + + if secretPath == "" { + secretPath = "/" + } + + if projectId == "" || environment == "" { + log.Warn(). + Str("method", r.Method). + Str("path", r.URL.Path). + Msg("Missing projectId or environment for cache purging - skipping cache purge") + return + } else { + + log.Debug(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("projectId", projectId). + Str("environment", environment). + Str("secretPath", secretPath). + Msg("Attempting mutation purging across all tokens") + purgedCount := cache.PurgeByMutation(projectId, environment, secretPath) + + if purgedCount == 1 { + log.Info(). + Str("mutationPath", secretPath). + Msg("Entry purged") + } else { + log.Info(). + Int("purgedCount", purgedCount). + Str("mutationPath", secretPath). 
+ Msg("Entries purged") + } + } + } + + // -- Cache Set -- + + if isCacheable && token != "" && resp.StatusCode == http.StatusOK { + cacheKey := proxy.GenerateCacheKey(r.Method, r.URL.Path, r.URL.RawQuery, token) + + queryParams := r.URL.Query() + // Support both v3 (workspaceId/workspaceSlug) and v4 (projectId) + projectId := queryParams.Get("projectId") + if projectId == "" { + projectId = queryParams.Get("workspaceId") + } + if projectId == "" { + projectId = queryParams.Get("workspaceSlug") + } + environment := queryParams.Get("environment") + secretPath := queryParams.Get("secretPath") + if secretPath == "" { + secretPath = "/" + } + + if r.URL.Path == "/api/v3/secrets" || r.URL.Path == "/api/v4/secrets" || + r.URL.Path == "/api/v3/secrets/raw" || r.URL.Path == "/api/v4/secrets/raw" { + recursive := queryParams.Get("recursive") + if recursive == "true" { + secretPath = secretPath + "*" + } + } + + indexEntry := proxy.IndexEntry{ + CacheKey: cacheKey, + SecretPath: secretPath, + EnvironmentSlug: environment, + ProjectId: projectId, + } + + cachedResp := &http.Response{ + StatusCode: resp.StatusCode, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(bodyBytes)), + } + + proxy.CopyHeaders(cachedResp.Header, resp.Header) + + if indexEntry.ProjectId != "" && indexEntry.EnvironmentSlug != "" { + + cache.Set(cacheKey, r, cachedResp, token, indexEntry) + + log.Debug(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("cacheKey", cacheKey). + Msg("Secret response cached successfully") + } else { + log.Warn(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("cacheKey", cacheKey). + Msg("Secret response not cached because project ID or environment slug is empty") + } + } + + log.Debug(). + Str("method", r.Method). + Str("path", r.URL.Path). + Int("status", resp.StatusCode). 
+ Msg("Request forwarded successfully") + } + + // Add proxy handler to mux + mux.HandleFunc("/", proxyHandler) + + var tlsConfig *tls.Config + if tlsEnabled { + cert, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile) + if err != nil { + util.HandleError(err, fmt.Sprintf("Failed to load TLS certificate and key: %s", err)) + } + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + } + + server := &http.Server{ + Addr: listenAddress, + Handler: mux, + TLSConfig: tlsConfig, + } + + resyncCtx, resyncCancel := context.WithCancel(context.Background()) + defer resyncCancel() + + go proxy.StartBackgroundLoops(resyncCtx, cache, domainURL, httpClient, evictionStrategy, accessTokenCheckInterval, staticSecretsRefreshInterval) + + // Handle graceful shutdown + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-sigCh + log.Info().Msgf("Received signal %v, shutting down proxy server...", sig) + + // Cancel resync goroutine + resyncCancel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := server.Shutdown(ctx); err != nil { + log.Error().Err(err).Msg("Error during server shutdown") + os.Exit(1) + } + + log.Info().Msg("Proxy server shutdown complete") + os.Exit(0) + }() + + if tlsEnabled { + log.Info().Msgf("Infisical proxy server starting on %s with TLS enabled", listenAddress) + } else { + log.Info().Msgf("Infisical proxy server starting on %s", listenAddress) + } + + log.Info().Msgf("Forwarding requests to %s", domain) + + if tlsEnabled { + if err := server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { + util.HandleError(err, "Failed to start proxy server with TLS") + } + } else { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + util.HandleError(err, "Failed to start proxy server") + } + } + +} + +func printCacheDebug(cmd *cobra.Command, args []string) { + if util.CLI_VERSION 
!= "devel" { + util.PrintErrorMessageAndExit("This command is only available in dev mode (when CLI_VERSION is 'devel').") + } + + listenAddress, err := cmd.Flags().GetString("listen-address") + if err != nil { + util.HandleError(err, "Unable to parse listen-address flag") + } + + if listenAddress == "" { + util.PrintErrorMessageAndExit("Listen-address flag is required") + } + + baseURL := "http://" + listenAddress + if strings.HasPrefix(listenAddress, ":") { + baseURL = "http://localhost" + listenAddress + } + + debugURL := baseURL + "/_debug/cache" + resp, err := http.Get(debugURL) + if err != nil { + util.HandleError(err, fmt.Sprintf("Failed to connect to proxy at %s. Make sure the proxy is running in dev mode (CLI_VERSION='devel')", listenAddress)) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + util.PrintErrorMessageAndExit(fmt.Sprintf("Failed to get cache debug info: %s", string(body))) + } + + var debugInfo proxy.CacheDebugInfo + if err := json.NewDecoder(resp.Body).Decode(&debugInfo); err != nil { + util.HandleError(err, "Failed to decode cache debug info") + } + + output, err := json.MarshalIndent(debugInfo, "", " ") + if err != nil { + util.HandleError(err, "Failed to marshal cache debug info") + } + + fmt.Println("Cache Debug Information:") + fmt.Println(string(output)) +} + +func isStreamingEndpoint(path string) bool { + return strings.HasPrefix(path, "/api/v1/events/") +} + +func init() { + proxyStartCmd.Flags().String("domain", "", "Domain of your Infisical instance (e.g., https://app.infisical.com for cloud, https://my-self-hosted-instance.com for self-hosted)") + proxyStartCmd.Flags().String("listen-address", "localhost:8081", "The address for the proxy server to listen on. Defaults to localhost:8081") + proxyStartCmd.Flags().String("eviction-strategy", string(CacheEvictionStrategyOptimistic), "Cache eviction strategy. 
'optimistic' keeps cached data when Infisical is unreachable for high availability. Currently only 'optimistic' is supported.") + proxyStartCmd.Flags().String("access-token-check-interval", "5m", "How often to validate that access tokens are still valid (e.g., 5m, 1h). Defaults to 5m.") + proxyStartCmd.Flags().String("static-secrets-refresh-interval", "1h", "How often to refresh cached secrets (e.g., 30m, 1h, 1d). Defaults to 1h.") + proxyStartCmd.Flags().String("tls-cert-file", "", "The path to the TLS certificate file for the proxy server. Required when `tls-enabled` is set to true (default)") + proxyStartCmd.Flags().String("tls-key-file", "", "The path to the TLS key file for the proxy server. Required when `tls-enabled` is set to true (default)") + proxyStartCmd.Flags().Bool("tls-enabled", true, "Whether to enable TLS for the proxy server. Defaults to true") + + proxyDebugCmd.Flags().String("listen-address", "localhost:8081", "The address where the proxy server is listening. Defaults to localhost:8081") + + proxyCmd.AddCommand(proxyStartCmd) + proxyCmd.AddCommand(proxyDebugCmd) + rootCmd.AddCommand(proxyCmd) +} diff --git a/packages/proxy/cache.go b/packages/proxy/cache.go new file mode 100644 index 00000000..5a048fd8 --- /dev/null +++ b/packages/proxy/cache.go @@ -0,0 +1,757 @@ +package proxy + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/Infisical/infisical-merge/packages/util/cache" + "github.com/rs/zerolog/log" +) + +// Storage key prefixes +const ( + prefixEntry = "entry:" + prefixToken = "token:" + prefixPath = "path:" +) + +type IndexEntry struct { + CacheKey string `json:"cacheKey"` + SecretPath string `json:"secretPath"` + EnvironmentSlug string `json:"environmentSlug"` + ProjectId string `json:"projectId"` +} + +type CachedRequest struct { + Method string `json:"method"` + RequestURI string `json:"requestUri"` + Headers http.Header `json:"headers"` + CachedAt time.Time 
`json:"cachedAt"` +} + +type CachedResponse struct { + StatusCode int `json:"statusCode"` + Header http.Header `json:"header"` + BodyBytes []byte `json:"bodyBytes"` +} + +// StoredCacheEntry is the structure stored in EncryptedStorage +type StoredCacheEntry struct { + Request *CachedRequest `json:"request"` + Response *CachedResponse `json:"response"` + Token string `json:"token"` + Index IndexEntry `json:"index"` +} + +// PathIndexMarker is a simple marker stored at path index keys +type PathIndexMarker struct { + CacheKey string `json:"cacheKey"` +} + +// Cache is an HTTP response cache fully backed by EncryptedStorage +type Cache struct { + storage *cache.EncryptedStorage + mu sync.RWMutex +} + +// NewCache creates a cache with the specified options +func NewCache(opts cache.EncryptedStorageOptions) (*Cache, error) { + storage, err := cache.NewEncryptedStorage(opts) + if err != nil { + return nil, fmt.Errorf("failed to create cache storage: %w", err) + } + + return &Cache{ + storage: storage, + }, nil +} + +// Close closes the underlying storage +func (c *Cache) Close() error { + return c.storage.Close() +} + +// hashToken creates a short hash of the token for use in storage keys +// This avoids storing the full token in key names while still being unique +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:8]) // First 8 bytes = 16 hex chars +} + +// buildEntryKey builds the storage key for a cache entry +func buildEntryKey(cacheKey string) string { + return prefixEntry + cacheKey +} + +// buildTokenIndexKey builds the storage key for token index entry +func buildTokenIndexKey(token, cacheKey string) string { + return prefixToken + hashToken(token) + ":" + cacheKey +} + +// buildTokenIndexPrefix builds the prefix for all token index entries for a token +func buildTokenIndexPrefix(token string) string { + return prefixToken + hashToken(token) + ":" +} + +// buildPathIndexKey builds the storage key for 
path index entry +// Key format: path:{projectId}:{envSlug}:{tokenHash}:{escapedSecretPath}:{cacheKey} +func buildPathIndexKey(token string, indexEntry IndexEntry) string { + // Escape colons in secretPath to avoid key parsing issues. + // Currently not relevant as we don't support colons in secret paths, but if we decide to broaden our allowed folder naming in the future, this would be needed + escapedPath := strings.ReplaceAll(indexEntry.SecretPath, ":", "\\:") + key := fmt.Sprintf("%s%s:%s:%s:%s:%s", + prefixPath, + indexEntry.ProjectId, + indexEntry.EnvironmentSlug, + hashToken(token), + escapedPath, + indexEntry.CacheKey, + ) + + log.Debug().Str("pathIndexKey", key).Msg("Built path index key") + + return key +} + +// buildPathIndexPrefixForProject builds the prefix for all path entries matching a project+env +func buildPathIndexPrefixForProject(projectId, envSlug string) string { + return fmt.Sprintf("%s%s:%s:", prefixPath, projectId, envSlug) +} + +func IsSecretsEndpoint(path string) bool { + return (strings.HasPrefix(path, "/api/v3/secrets/") || strings.HasPrefix(path, "/api/v4/secrets/")) || + path == "/api/v3/secrets" || path == "/api/v4/secrets" +} + +func IsCacheableRequest(path string, method string) bool { + if method != http.MethodGet { + return false + } + + return IsSecretsEndpoint(path) +} + +func (c *Cache) Get(cacheKey string) (*http.Response, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + var entry StoredCacheEntry + err := c.storage.Get(buildEntryKey(cacheKey), &entry) + if err != nil { + return nil, false + } + + if entry.Response == nil { + return nil, false + } + + resp := &http.Response{ + StatusCode: entry.Response.StatusCode, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(entry.Response.BodyBytes)), + } + + CopyHeaders(resp.Header, entry.Response.Header) + + return resp, true +} + +func (c *Cache) Set(cacheKey string, req *http.Request, resp *http.Response, token string, indexEntry IndexEntry) { + c.mu.Lock() + defer 
c.mu.Unlock() + + // Read response body + var bodyBytes []byte + if resp.Body != nil { + var err error + bodyBytes, err = io.ReadAll(resp.Body) + if err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to read response body") + bodyBytes = nil + } + } + + // Extract request metadata + requestURI := req.URL.RequestURI() + requestHeaders := make(http.Header) + CopyHeaders(requestHeaders, req.Header) + + // Extract response data + responseHeader := make(http.Header) + CopyHeaders(responseHeader, resp.Header) + + entry := StoredCacheEntry{ + Request: &CachedRequest{ + Method: req.Method, + RequestURI: requestURI, + Headers: requestHeaders, + CachedAt: time.Now(), + }, + Response: &CachedResponse{ + StatusCode: resp.StatusCode, + Header: responseHeader, + BodyBytes: bodyBytes, + }, + Token: token, + Index: indexEntry, + } + + // Store main entry + if err := c.storage.Set(buildEntryKey(cacheKey), entry); err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to store cache entry") + return + } + + // Store token index entry + tokenIndexKey := buildTokenIndexKey(token, cacheKey) + if err := c.storage.Set(tokenIndexKey, indexEntry); err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to store token index entry") + } + + // Store path index entry + pathIndexKey := buildPathIndexKey(token, indexEntry) + if err := c.storage.Set(pathIndexKey, PathIndexMarker{CacheKey: cacheKey}); err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to store path index entry") + } +} + +// UpdateResponse updates only the response data and cachedAt timestamp for an existing cache entry +func (c *Cache) UpdateResponse(cacheKey string, statusCode int, header http.Header, bodyBytes []byte) { + c.mu.Lock() + defer c.mu.Unlock() + + var entry StoredCacheEntry + err := c.storage.Get(buildEntryKey(cacheKey), &entry) + if err != nil { + return + } + + // Deep copy response header + responseHeader := make(http.Header) + 
CopyHeaders(responseHeader, header) + + // Deep copy bodyBytes + bodyBytesCopy := make([]byte, len(bodyBytes)) + copy(bodyBytesCopy, bodyBytes) + + entry.Response.StatusCode = statusCode + entry.Response.Header = responseHeader + entry.Response.BodyBytes = bodyBytesCopy + entry.Request.CachedAt = time.Now() + + // Update in storage + if err := c.storage.Set(buildEntryKey(cacheKey), entry); err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to update cache entry") + } +} + +func CopyHeaders(dst, src http.Header) { + for key, values := range src { + for _, value := range values { + dst.Add(key, value) + } + } +} + +func ExtractTokenFromRequest(r *http.Request) string { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "" + } + + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + return "" + } + + return parts[1] +} + +// GenerateCacheKey generates a cache key for a request by hashing the method, path, query, and token +func GenerateCacheKey(method, path, query, token string) string { + data := method + path + query + token + hash := sha256.Sum256([]byte(data)) + return hex.EncodeToString(hash[:]) +} + +func matchesPath(storedPath, queryPath string) bool { + if strings.HasSuffix(storedPath, "/*") { + base := strings.TrimSuffix(storedPath, "/*") + + if queryPath == base { + return true + } + + // Check if queryPath is under base (e.g., base="/test", queryPath="/test/sub") + return strings.HasPrefix(queryPath+"/", base+"/") + } + + if storedPath == queryPath { + return true + } + + return false +} + +// GetExpiredRequests returns only expired request data for resync +func (c *Cache) GetExpiredRequests(cacheTTL time.Duration) map[string]*CachedRequest { + c.mu.RLock() + defer c.mu.RUnlock() + + now := time.Now() + requests := make(map[string]*CachedRequest) + + // Get all entry keys + entryKeys, err := c.storage.GetKeysByPrefix(prefixEntry) + if err != nil { + 
log.Error().Err(err).Msg("Failed to get entry keys for expired requests check") + return requests + } + + for _, key := range entryKeys { + var entry StoredCacheEntry + if err := c.storage.Get(key, &entry); err != nil { + continue + } + + if entry.Request == nil { + continue + } + + // Only include entries where cache-ttl has expired + age := now.Sub(entry.Request.CachedAt) + if age <= cacheTTL { + continue + } + + // Extract cacheKey from storage key (remove prefix) + cacheKey := strings.TrimPrefix(key, prefixEntry) + + requestCopy := &CachedRequest{ + Method: entry.Request.Method, + RequestURI: entry.Request.RequestURI, + Headers: make(http.Header), + CachedAt: entry.Request.CachedAt, + } + + CopyHeaders(requestCopy.Headers, entry.Request.Headers) + + requests[cacheKey] = requestCopy + } + + return requests +} + +func (c *Cache) EvictEntry(cacheKey string) { + c.mu.Lock() + defer c.mu.Unlock() + + c.evictEntryUnsafe(cacheKey) +} + +// evictEntryUnsafe evicts an entry without acquiring the lock (caller must hold lock) +func (c *Cache) evictEntryUnsafe(cacheKey string) { + // Get the entry to find its token and index info + var entry StoredCacheEntry + if err := c.storage.Get(buildEntryKey(cacheKey), &entry); err != nil { + return + } + + // Remove main entry + if err := c.storage.Delete(buildEntryKey(cacheKey)); err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to delete cache entry") + } + + // Remove token index entry + tokenIndexKey := buildTokenIndexKey(entry.Token, cacheKey) + if err := c.storage.Delete(tokenIndexKey); err != nil { + log.Debug().Err(err).Str("cacheKey", cacheKey).Msg("Failed to delete token index entry") + } + + // Remove path index entry + pathIndexKey := buildPathIndexKey(entry.Token, entry.Index) + if err := c.storage.Delete(pathIndexKey); err != nil { + log.Debug().Err(err).Str("cacheKey", cacheKey).Msg("Failed to delete path index entry") + } +} + +// GetAllTokens returns all unique tokens that have cached entries 
+func (c *Cache) GetAllTokens() []string { + c.mu.RLock() + defer c.mu.RUnlock() + + // Get all token index keys and extract unique token hashes + tokenKeys, err := c.storage.GetKeysByPrefix(prefixToken) + if err != nil { + log.Error().Err(err).Msg("Failed to get token index keys") + return nil + } + + // We need to get unique tokens, but we only have hashes in the keys + // We need to look up the actual token from entries + tokenHashToToken := make(map[string]string) + + for _, key := range tokenKeys { + // Key format: token:{tokenHash}:{cacheKey} + parts := strings.SplitN(strings.TrimPrefix(key, prefixToken), ":", 2) + if len(parts) < 2 { + continue + } + tokenHash := parts[0] + cacheKey := parts[1] + + if _, exists := tokenHashToToken[tokenHash]; exists { + continue // Already found this token + } + + // Get the entry to find the actual token + var entry StoredCacheEntry + if err := c.storage.Get(buildEntryKey(cacheKey), &entry); err == nil { + tokenHashToToken[tokenHash] = entry.Token + } + } + + tokens := make([]string, 0, len(tokenHashToToken)) + for _, token := range tokenHashToToken { + tokens = append(tokens, token) + } + + return tokens +} + +// GetFirstRequestForToken gets the first request (any, regardless of expiration) for a token +func (c *Cache) GetFirstRequestForToken(token string) (cacheKey string, request *CachedRequest, found bool) { + c.mu.Lock() + defer c.mu.Unlock() + + tokenPrefix := buildTokenIndexPrefix(token) + tokenKeys, err := c.storage.GetKeysByPrefix(tokenPrefix) + if err != nil || len(tokenKeys) == 0 { + return "", nil, false + } + + // Get the first cacheKey from the token's entries + for _, key := range tokenKeys { + // Key format: token:{tokenHash}:{cacheKey} + parts := strings.SplitN(strings.TrimPrefix(key, prefixToken), ":", 2) + if len(parts) < 2 { + continue + } + cacheKey := parts[1] + + var entry StoredCacheEntry + if err := c.storage.Get(buildEntryKey(cacheKey), &entry); err != nil { + // Delete orphan index entry + 
c.storage.Delete(key) + continue + } + + if entry.Request == nil { + c.storage.Delete(key) + continue + } + + requestCopy := &CachedRequest{ + Method: entry.Request.Method, + RequestURI: entry.Request.RequestURI, + Headers: make(http.Header), + CachedAt: entry.Request.CachedAt, + } + + CopyHeaders(requestCopy.Headers, entry.Request.Headers) + + return cacheKey, requestCopy, true + } + + return "", nil, false +} + +// EvictAllEntriesForToken evicts all cache entries for a given token +func (c *Cache) EvictAllEntriesForToken(token string) int { + c.mu.Lock() + defer c.mu.Unlock() + + tokenPrefix := buildTokenIndexPrefix(token) + tokenKeys, err := c.storage.GetKeysByPrefix(tokenPrefix) + if err != nil { + return 0 + } + + evictedCount := 0 + + for _, key := range tokenKeys { + // Key format: token:{tokenHash}:{cacheKey} + parts := strings.SplitN(strings.TrimPrefix(key, prefixToken), ":", 2) + if len(parts) < 2 { + continue + } + cacheKey := parts[1] + + c.evictEntryUnsafe(cacheKey) + evictedCount++ + } + + return evictedCount +} + +// RemoveTokenFromIndex removes all index entries for a token (without deleting main entries) +// This is a cleanup function called rarely for orphaned tokens +func (c *Cache) RemoveTokenFromIndex(token string) { + c.mu.Lock() + defer c.mu.Unlock() + + tokenPrefix := buildTokenIndexPrefix(token) + c.storage.DeleteByPrefix(tokenPrefix) + + // Also delete path index entries for this token + // since path keys are prefixed by projectId:envSlug + // we need to scan all path keys to find those containing this token's hash + tokenHash := hashToken(token) + pathKeys, err := c.storage.GetKeysByPrefix(prefixPath) + if err != nil { + log.Debug().Err(err).Msg("Failed to get path keys for token index cleanup") + return + } + + for _, key := range pathKeys { + // Key format: path:{projectId}:{envSlug}:{tokenHash}:{secretPath}:{cacheKey} + withoutPrefix := strings.TrimPrefix(key, prefixPath) + parts := strings.SplitN(withoutPrefix, ":", 4) + if 
len(parts) < 3 { + continue + } + keyTokenHash := parts[2] + if keyTokenHash == tokenHash { + c.storage.Delete(key) + } + } +} + +// PurgeByMutation purges cache entries across ALL tokens that match the mutation path +func (c *Cache) PurgeByMutation(projectID, envSlug, mutationPath string) int { + c.mu.Lock() + defer c.mu.Unlock() + + purgedCount := 0 + + prefix := buildPathIndexPrefixForProject(projectID, envSlug) + pathKeys, err := c.storage.GetKeysByPrefix(prefix) + if err != nil { + log.Error().Err(err).Msg("Failed to get path index keys for mutation purge") + return 0 + } + + for _, key := range pathKeys { + // Key format: path:{projectId}:{envSlug}:{tokenHash}:{escapedSecretPath}:{cacheKey} + // We already filtered by projectId:envSlug via prefix, so extract remaining parts + withoutPrefix := strings.TrimPrefix(key, prefix) + parts := strings.SplitN(withoutPrefix, ":", 3) + if len(parts) < 3 { + continue + } + + // parts[0] = tokenHash (not needed for matching) + keySecretPath := strings.ReplaceAll(parts[1], "\\:", ":") // Unescape colons — BUG(review): the SplitN above does not honor the "\:" escaping, so a secret path containing ':' splits at the escaped colon, truncating keySecretPath and corrupting keyCacheKey; such entries are never purged. Split keyCacheKey off with strings.LastIndex(withoutPrefix, ":") instead — TODO confirm buildPathIndexKey's escaping rules. + keyCacheKey := parts[2] + + if matchesPath(keySecretPath, mutationPath) { + c.evictEntryUnsafe(keyCacheKey) + purgedCount++ + } + } + + return purgedCount +} + +// CompoundPathIndexDebugInfo represents the compound path index structure +type CompoundPathIndexDebugInfo struct { + Token string `json:"token"` + Projects map[string]ProjectDebugInfo `json:"projects"` + TotalPaths int `json:"totalPaths"` + TotalKeys int `json:"totalKeys"` +} + +// ProjectDebugInfo represents project-level debug info +type ProjectDebugInfo struct { + ProjectID string `json:"projectId"` + Environments map[string]EnvironmentDebugInfo `json:"environments"` + TotalPaths int `json:"totalPaths"` + TotalKeys int `json:"totalKeys"` +} + +// EnvironmentDebugInfo represents environment-level debug info +type EnvironmentDebugInfo struct { + EnvironmentSlug string `json:"environmentSlug"` + Paths map[string]PathDebugInfo `json:"paths"` + TotalKeys int 
`json:"totalKeys"` +} + +// CacheKeyDebugInfo represents a cache key with its timestamp +type CacheKeyDebugInfo struct { + CacheKey string `json:"cacheKey"` + CachedAt time.Time `json:"cachedAt"` +} + +// PathDebugInfo represents path-level debug info +type PathDebugInfo struct { + SecretPath string `json:"secretPath"` + CacheKeys []CacheKeyDebugInfo `json:"cacheKeys"` + KeyCount int `json:"keyCount"` +} + +// CacheDebugInfo contains debug information about the cache +type CacheDebugInfo struct { + TotalEntries int `json:"totalEntries"` + TotalTokens int `json:"totalTokens"` + TotalSizeBytes int64 `json:"totalSizeBytes"` + EntriesByToken map[string]int `json:"entriesByToken"` + CacheKeys []CacheKeyDebugInfo `json:"cacheKeys"` + TokenIndex map[string][]IndexEntry `json:"tokenIndex"` + CompoundPathIndex []CompoundPathIndexDebugInfo `json:"compoundPathIndex"` +} + +// GetDebugInfo returns debug information about the cache (dev mode only) +func (c *Cache) GetDebugInfo() CacheDebugInfo { + c.mu.RLock() + defer c.mu.RUnlock() + + var totalSize int64 + entriesByToken := make(map[string]int) + tokenIndex := make(map[string][]IndexEntry) + cacheKeys := make([]CacheKeyDebugInfo, 0) + totalEntries := 0 + + // Get all entry keys + entryKeys, err := c.storage.GetKeysByPrefix(prefixEntry) + if err != nil { + log.Error().Err(err).Msg("Failed to get entry keys for debug info") + return CacheDebugInfo{} + } + + // Maps for building compound path index debug info + // tokenHash -> projectID -> envSlug -> secretPath -> []CacheKeyDebugInfo + pathIndexData := make(map[string]map[string]map[string]map[string][]CacheKeyDebugInfo) + tokenHashToToken := make(map[string]string) + + for _, key := range entryKeys { + var entry StoredCacheEntry + if err := c.storage.Get(key, &entry); err != nil { + continue + } + + cacheKey := strings.TrimPrefix(key, prefixEntry) + tokenHash := hashToken(entry.Token) + tokenHashToToken[tokenHash] = entry.Token + + // Count entries per token + 
entriesByToken[entry.Token]++ + + // Add to token index + if tokenIndex[entry.Token] == nil { + tokenIndex[entry.Token] = make([]IndexEntry, 0) + } + tokenIndex[entry.Token] = append(tokenIndex[entry.Token], entry.Index) + + // Calculate size + if entry.Response != nil { + totalSize += int64(len(entry.Response.BodyBytes)) + } + + // Add to cache keys list + if entry.Request != nil { + cacheKeys = append(cacheKeys, CacheKeyDebugInfo{ + CacheKey: cacheKey, + CachedAt: entry.Request.CachedAt, + }) + } + + totalEntries++ + + // Build path index data + if pathIndexData[tokenHash] == nil { + pathIndexData[tokenHash] = make(map[string]map[string]map[string][]CacheKeyDebugInfo) + } + if pathIndexData[tokenHash][entry.Index.ProjectId] == nil { + pathIndexData[tokenHash][entry.Index.ProjectId] = make(map[string]map[string][]CacheKeyDebugInfo) + } + if pathIndexData[tokenHash][entry.Index.ProjectId][entry.Index.EnvironmentSlug] == nil { + pathIndexData[tokenHash][entry.Index.ProjectId][entry.Index.EnvironmentSlug] = make(map[string][]CacheKeyDebugInfo) + } + keyInfo := CacheKeyDebugInfo{CacheKey: cacheKey} + if entry.Request != nil { + keyInfo.CachedAt = entry.Request.CachedAt + } + pathIndexData[tokenHash][entry.Index.ProjectId][entry.Index.EnvironmentSlug][entry.Index.SecretPath] = + append(pathIndexData[tokenHash][entry.Index.ProjectId][entry.Index.EnvironmentSlug][entry.Index.SecretPath], keyInfo) + } + + // Build compound path index debug info + compoundPathIndex := make([]CompoundPathIndexDebugInfo, 0) + for tokenHash, projectMap := range pathIndexData { + token := tokenHashToToken[tokenHash] + projects := make(map[string]ProjectDebugInfo) + totalPaths := 0 + totalKeys := 0 + + for projectID, envMap := range projectMap { + environments := make(map[string]EnvironmentDebugInfo) + projectTotalPaths := 0 + projectTotalKeys := 0 + + for envSlug, pathsMap := range envMap { + paths := make(map[string]PathDebugInfo) + envTotalKeys := 0 + + for secretPath, keyInfos := range 
pathsMap { + paths[secretPath] = PathDebugInfo{ + SecretPath: secretPath, + CacheKeys: keyInfos, + KeyCount: len(keyInfos), + } + envTotalKeys += len(keyInfos) + projectTotalPaths++ + } + + environments[envSlug] = EnvironmentDebugInfo{ + EnvironmentSlug: envSlug, + Paths: paths, + TotalKeys: envTotalKeys, + } + projectTotalKeys += envTotalKeys + } + + projects[projectID] = ProjectDebugInfo{ + ProjectID: projectID, + Environments: environments, + TotalPaths: projectTotalPaths, + TotalKeys: projectTotalKeys, + } + totalPaths += projectTotalPaths + totalKeys += projectTotalKeys + } + + compoundPathIndex = append(compoundPathIndex, CompoundPathIndexDebugInfo{ + Token: token, + Projects: projects, + TotalPaths: totalPaths, + TotalKeys: totalKeys, + }) + } + + return CacheDebugInfo{ + TotalEntries: totalEntries, + TotalTokens: len(tokenHashToToken), + TotalSizeBytes: totalSize, + EntriesByToken: entriesByToken, + CacheKeys: cacheKeys, + TokenIndex: tokenIndex, + CompoundPathIndex: compoundPathIndex, + } +} diff --git a/packages/proxy/resync.go b/packages/proxy/resync.go new file mode 100644 index 00000000..f8d989b0 --- /dev/null +++ b/packages/proxy/resync.go @@ -0,0 +1,316 @@ +package proxy + +import ( + "context" + "encoding/json" + "io" + "math/rand" + "net/http" + "net/url" + "regexp" + "sort" + "strconv" + "time" + + "github.com/rs/zerolog/log" +) + +var rateLimitSecondsRegex = regexp.MustCompile(`(\d+)\s+seconds?`) + +// maskToken masks a token showing only first 5 and last 5 characters +func maskToken(token string) string { + if len(token) <= 10 { + return "***" + } + return token[:5] + "..." + token[len(token)-5:] +} + +// parseRateLimitSeconds extracts retry-after seconds from rate limit error message +// Expected format: "Rate limit exceeded. 
Please try again in 57 seconds" +// Returns default of 10 seconds if parsing fails +func parseRateLimitSeconds(body []byte) int { + var errorResponse struct { + Message string `json:"message"` + } + + var seconds int = 10 + + if err := json.Unmarshal(body, &errorResponse); err != nil { + return seconds + } + + matches := rateLimitSecondsRegex.FindStringSubmatch(errorResponse.Message) + if len(matches) < 2 { + return 10 + } + + seconds, err := strconv.Atoi(matches[1]) + if err != nil { + return 10 + } + + return seconds +} + +func handleResyncResponse(cache *Cache, cacheKey string, requestURI string, resp *http.Response) (refetched bool, evicted bool, rateLimited bool, retryAfterSeconds int) { + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", cacheKey). + Msg("Failed to read response body during resync") + return false, false, false, 0 + } + + // Update only response data (IndexEntry doesn't change during resync) + cache.UpdateResponse(cacheKey, resp.StatusCode, resp.Header, bodyBytes) + + log.Debug(). + Str("cacheKey", cacheKey). + Str("requestURI", requestURI). + Msg("Successfully refetched and updated cache entry") + return true, false, false, 0 + case http.StatusUnauthorized, http.StatusForbidden, http.StatusNotFound: + // Evict entry on 401/403/404 + cache.EvictEntry(cacheKey) + + log.Info(). + Str("hash", cacheKey). + Msg("Entry evicted") + return false, true, false, 0 + case http.StatusTooManyRequests: + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", cacheKey). + Msg("Failed to read rate limit response body, using default 10 seconds") + return false, false, true, 10 + } + + retryAfter := parseRateLimitSeconds(bodyBytes) + + log.Debug(). + Str("cacheKey", cacheKey). + Str("requestURI", requestURI). + Int("retryAfterSeconds", retryAfter). 
+ Msg("Rate limited during resync") + return false, false, true, retryAfter + default: + // Other error status codes - keep stale entry + log.Debug(). + Str("cacheKey", cacheKey). + Str("requestURI", requestURI). + Int("statusCode", resp.StatusCode). + Msg("Unexpected status code during resync - keeping stale entry") + return false, false, false, 0 + } +} + +func reconstructProxyRequest(domainURL *url.URL, request *CachedRequest) (*http.Request, error) { + targetURL := *domainURL + parsedURI, err := url.Parse(request.RequestURI) + if err != nil { + return nil, err + } + + targetURL.Path = domainURL.Path + parsedURI.Path + targetURL.RawQuery = parsedURI.RawQuery + + proxyReq, err := http.NewRequest(request.Method, targetURL.String(), nil) + if err != nil { + return nil, err + } + + CopyHeaders(proxyReq.Header, request.Headers) + return proxyReq, nil +} + +// runAccessTokenValidation validates all cached tokens and evicts entries for invalid tokens +func runAccessTokenValidation(cache *Cache, domainURL *url.URL, httpClient *http.Client) { + log.Info().Msg("Starting access token validation") + + tokens := cache.GetAllTokens() + tokensEvicted := 0 + + for _, token := range tokens { + // Add jitter to avoid bursts + time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) + + cacheKey, request, found := cache.GetFirstRequestForToken(token) + if !found { + cache.RemoveTokenFromIndex(token) + log.Debug(). + Str("token", maskToken(token)). + Msg("Removed orphaned token entry during token validation") + continue + } + + proxyReq, err := reconstructProxyRequest(domainURL, request) + if err != nil { + log.Error(). + Err(err). + Str("token", maskToken(token)). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). 
+ Msg("Failed to reconstruct request during token validation") + continue + } + + resp, err := httpClient.Do(proxyReq) + if err != nil || (resp != nil && resp.StatusCode >= 500) { + // Keep entries for high availability (optimistic eviction strategy) + if resp != nil { + resp.Body.Close() + } + log.Error(). + Err(err). + Str("token", maskToken(token)). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). + Msg("Network error during token validation - keeping entries (optimistic strategy)") + continue + } + + // If 401, evict all entries for this token + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + evictedCount := cache.EvictAllEntriesForToken(token) + resp.Body.Close() + tokensEvicted++ + + if evictedCount == 1 { + log.Info(). + Str("token", maskToken(token)). + Msg("Token invalid - entry evicted") + } else { + log.Info(). + Int("evictedCount", evictedCount). + Str("token", maskToken(token)). + Msg("Token invalid - entries evicted") + } + } else { + resp.Body.Close() + } + } + + log.Info(). + Int("tokensChecked", len(tokens)). + Int("tokensEvicted", tokensEvicted). 
+ Msg("Access token validation completed") +} + +// runStaticSecretsRefresh refreshes all cached secrets that have exceeded the refresh interval +func runStaticSecretsRefresh(cache *Cache, domainURL *url.URL, httpClient *http.Client, refreshInterval time.Duration) { + log.Info().Msg("Starting static secrets refresh") + + cycleStartTime := time.Now() + + requests := cache.GetExpiredRequests(refreshInterval) + + // Convert map to slice and sort by CachedAt (oldest first) + type orderedEntry struct { + cacheKey string + request *CachedRequest + } + ordered := make([]orderedEntry, 0, len(requests)) + for key, req := range requests { + ordered = append(ordered, orderedEntry{key, req}) + } + sort.Slice(ordered, func(i, j int) bool { + return ordered[i].request.CachedAt.Before(ordered[j].request.CachedAt) + }) + + refetched := 0 + evicted := 0 + + for _, entry := range ordered { + // Add jitter to avoid bursts + time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) + + proxyReq, err := reconstructProxyRequest(domainURL, entry.request) + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", entry.cacheKey). + Str("requestURI", entry.request.RequestURI). + Msg("Failed to parse requestURI during secrets refresh") + continue + } + + resp, err := httpClient.Do(proxyReq) + if err != nil || (resp != nil && resp.StatusCode >= 500) { + // Keep stale entry for high availability (optimistic eviction strategy) + if resp != nil { + resp.Body.Close() + } + log.Error(). + Err(err). + Str("cacheKey", entry.cacheKey). + Str("requestURI", entry.request.RequestURI). 
+ Msg("Network error during secrets refresh - keeping stale entry (optimistic strategy)") + continue + } + + refetchedResult, evictedResult, rateLimited, retryAfterSeconds := handleResyncResponse(cache, entry.cacheKey, entry.request.RequestURI, resp) + if refetchedResult { + refetched++ + } + if evictedResult { + evicted++ + } + + // Handle rate limiting + if rateLimited { + pauseDuration := time.Duration(retryAfterSeconds+2) * time.Second // 2 seconds buffer + timeUntilNextTick := refreshInterval - time.Since(cycleStartTime) + + if pauseDuration <= timeUntilNextTick { + log.Info(). + Int("pauseSeconds", retryAfterSeconds+2). + Msg("Rate limited, pausing secrets refresh") + time.Sleep(pauseDuration) + } else { + log.Warn(). + Int("pauseSeconds", retryAfterSeconds+2). + Msg("Rate limit pause exceeds refresh interval, remaining entries will be processed next cycle. Increase the static-secrets-refresh-interval value to prevent this behavior.") + break + } + } + } + + log.Info(). + Int("expiredEntries", len(requests)). + Int("refetched", refetched). + Int("evicted", evicted). + Msg("Static secrets refresh completed") +} + +// StartBackgroundLoops starts the background loops for token validation and secrets refresh +func StartBackgroundLoops(ctx context.Context, cache *Cache, domainURL *url.URL, httpClient *http.Client, evictionStrategy string, accessTokenCheckInterval time.Duration, staticSecretsRefreshInterval time.Duration) { + tokenTicker := time.NewTicker(accessTokenCheckInterval) + secretsTicker := time.NewTicker(staticSecretsRefreshInterval) + defer tokenTicker.Stop() + defer secretsTicker.Stop() + + log.Info(). + Str("evictionStrategy", evictionStrategy). + Str("accessTokenCheckInterval", accessTokenCheckInterval.String()). + Str("staticSecretsRefreshInterval", staticSecretsRefreshInterval.String()). 
+ Msg("Background loops started") + + for { + select { + case <-tokenTicker.C: + runAccessTokenValidation(cache, domainURL, httpClient) + case <-secretsTicker.C: + runStaticSecretsRefresh(cache, domainURL, httpClient, staticSecretsRefreshInterval) + case <-ctx.Done(): + log.Info().Msg("Background loops stopped") + return + } + } +} diff --git a/packages/util/agent.go b/packages/util/agent.go deleted file mode 100644 index 585ab9ed..00000000 --- a/packages/util/agent.go +++ /dev/null @@ -1,58 +0,0 @@ -package util - -import ( - "fmt" - "strconv" - "time" -) - -// ParseTimeDurationString converts a string representation of a polling interval to a time.Duration -func ParseTimeDurationString(pollingInterval string, allowLessThanOneSecond bool) (time.Duration, error) { - length := len(pollingInterval) - if length < 2 { - return 0, fmt.Errorf("invalid format") - } - - splitIndex := length - for i := length - 1; i >= 0; i-- { - if pollingInterval[i] >= '0' && pollingInterval[i] <= '9' { - splitIndex = i + 1 - break - } - } - - if splitIndex == 0 || splitIndex == length { - return 0, fmt.Errorf("invalid format: must contain both number and unit") - } - - numberPart := pollingInterval[:splitIndex] - unit := pollingInterval[splitIndex:] - - number, err := strconv.Atoi(numberPart) - if err != nil { - return 0, err - } - - switch unit { - case "s": - if number < 60 && !IsDevelopmentMode() && !allowLessThanOneSecond { - return 0, fmt.Errorf("polling interval must be at least 60 seconds") - } - return time.Duration(number) * time.Second, nil - case "ms": - if number < 1000 && !IsDevelopmentMode() && !allowLessThanOneSecond { - return 0, fmt.Errorf("polling interval must be at least 1000 milliseconds") - } - return time.Duration(number) * time.Millisecond, nil - case "m": - return time.Duration(number) * time.Minute, nil - case "h": - return time.Duration(number) * time.Hour, nil - case "d": - return time.Duration(number) * 24 * time.Hour, nil - case "w": - return 
time.Duration(number) * 7 * 24 * time.Hour, nil - default: - return 0, fmt.Errorf("invalid time unit") - } -} diff --git a/packages/util/cache/cache-storage.go b/packages/util/cache/cache-storage.go index 1dee11f9..44a7b263 100644 --- a/packages/util/cache/cache-storage.go +++ b/packages/util/cache/cache-storage.go @@ -11,13 +11,14 @@ import ( "reflect" "time" + "github.com/awnumar/memguard" "github.com/dgraph-io/badger/v3" "github.com/rs/zerolog/log" ) type EncryptedStorage struct { db *badger.DB - key [32]byte + key *memguard.LockedBuffer } type EncryptedStorageOptions struct { @@ -27,7 +28,7 @@ type EncryptedStorageOptions struct { InMemory bool // Only required if InMemory is false - EncryptionKey [32]byte + EncryptionKey *memguard.LockedBuffer } func NewEncryptedStorage(opts EncryptedStorageOptions) (*EncryptedStorage, error) { @@ -181,7 +182,140 @@ func (s *EncryptedStorage) Delete(key string) error { }) } +// GetKeysByPrefix returns all keys that start with the given prefix (keys only, no values) +func (s *EncryptedStorage) GetKeysByPrefix(prefix string) ([]string, error) { + var keys []string + + err := s.db.View(func(txn *badger.Txn) error { + opts := badger.DefaultIteratorOptions + opts.PrefetchValues = false // Keys only, much faster + it := txn.NewIterator(opts) + defer it.Close() + + prefixBytes := []byte(prefix) + for it.Seek(prefixBytes); it.ValidForPrefix(prefixBytes); it.Next() { + keys = append(keys, string(it.Item().Key())) + } + return nil + }) + + if err != nil { + return nil, fmt.Errorf("failed to get keys by prefix: %w", err) + } + + return keys, nil +} + +// GetByPrefix returns all key-value pairs where the key starts with the given prefix +func (s *EncryptedStorage) GetByPrefix(prefix string, destFactory func() interface{}) (map[string]interface{}, error) { + result := make(map[string]interface{}) + + err := s.db.View(func(txn *badger.Txn) error { + opts := badger.DefaultIteratorOptions + opts.PrefetchSize = 10 + it := txn.NewIterator(opts) 
+ defer it.Close() + + prefixBytes := []byte(prefix) + for it.Seek(prefixBytes); it.ValidForPrefix(prefixBytes); it.Next() { + item := it.Item() + key := string(item.Key()) + + encrypted, err := item.ValueCopy(nil) + if err != nil { + return fmt.Errorf("failed to copy value for key %s: %w", key, err) + } + + decrypted, err := s.decrypt(encrypted) + if err != nil { + return fmt.Errorf("failed to decrypt value for key %s: %w", key, err) + } + + dest := destFactory() + if err := json.Unmarshal(decrypted, dest); err != nil { + return fmt.Errorf("failed to unmarshal value for key %s: %w", key, err) + } + + result[key] = dest + } + return nil + }) + + if err != nil { + return nil, err + } + + return result, nil +} + +// DeleteByPrefix deletes all keys that start with the given prefix +// Deletions are batched to avoid exceeding BadgerDB's transaction size limits +func (s *EncryptedStorage) DeleteByPrefix(prefix string) (int, error) { + const batchSize = 1000 // Process deletions in batches to avoid transaction size limits + + log.Debug().Str("prefix", prefix).Msg("Deleting by prefix") + + // First, collect all keys to delete + keysToDelete, err := s.GetKeysByPrefix(prefix) + if err != nil { + return 0, err + } + + if len(keysToDelete) == 0 { + return 0, nil + } + + deletedCount := 0 + + // Process deletions in batches + for i := 0; i < len(keysToDelete); i += batchSize { + end := i + batchSize + if end > len(keysToDelete) { + end = len(keysToDelete) + } + batch := keysToDelete[i:end] + + err = s.db.Update(func(txn *badger.Txn) error { + for _, key := range batch { + if err := txn.Delete([]byte(key)); err != nil { + return fmt.Errorf("failed to delete key %s: %w", key, err) + } + } + return nil + }) + + if err != nil { + return deletedCount, fmt.Errorf("failed to delete batch starting at index %d: %w", i, err) + } + + deletedCount += len(batch) + } + + return deletedCount, nil +} + +// Exists checks if a key exists in the storage +func (s *EncryptedStorage) Exists(key 
string) (bool, error) { + var exists bool + + err := s.db.View(func(txn *badger.Txn) error { + _, err := txn.Get([]byte(key)) + if err == badger.ErrKeyNotFound { + exists = false + return nil + } + if err != nil { + return err + } + exists = true + return nil + }) + + return exists, err +} + func (s *EncryptedStorage) Close() error { + s.key.Destroy() // NOTE(review): EncryptedStorageOptions documents EncryptionKey as "Only required if InMemory is false", so s.key may be nil here — Destroy() on a nil *memguard.LockedBuffer panics; guard with `if s.key != nil` — TODO confirm NewEncryptedStorage always sets key. return s.db.Close() } @@ -215,7 +349,7 @@ func (s *EncryptedStorage) StartPeriodicGarbageCollection(context context.Contex } func (s *EncryptedStorage) encrypt(plaintext []byte) ([]byte, error) { - block, err := aes.NewCipher(s.key[:]) + block, err := aes.NewCipher(s.key.Bytes()) if err != nil { return nil, err } @@ -234,7 +368,7 @@ func (s *EncryptedStorage) encrypt(plaintext []byte) ([]byte, error) { } func (s *EncryptedStorage) decrypt(ciphertext []byte) ([]byte, error) { - block, err := aes.NewCipher(s.key[:]) + block, err := aes.NewCipher(s.key.Bytes()) if err != nil { return nil, err } diff --git a/packages/util/helper.go b/packages/util/helper.go index 30eebaa1..1e346cce 100644 --- a/packages/util/helper.go +++ b/packages/util/helper.go @@ -13,6 +13,7 @@ import ( "path" "runtime" "sort" + "strconv" "strings" "sync" "time" @@ -623,3 +624,58 @@ func OpenBrowser(url string) error { return cmd.Start() } + +// ParseTimeDurationString converts a string representation of a polling interval to a time.Duration +func ParseTimeDurationString(pollingInterval string, allowLessThanOneSecond bool) (time.Duration, error) { + length := len(pollingInterval) + if length < 2 { + return 0, fmt.Errorf("invalid format") + } + + splitIndex := length + for i := length - 1; i >= 0; i-- { + if pollingInterval[i] >= '0' && pollingInterval[i] <= '9' { + splitIndex = i + 1 + break + } + } + + if splitIndex == 0 || splitIndex == length { + return 0, fmt.Errorf("invalid format: must contain both number and unit") + } + + numberPart := pollingInterval[:splitIndex] + unit := pollingInterval[splitIndex:] + + number, err := 
strconv.Atoi(numberPart) + if err != nil { + return 0, err + } + + if number <= 0 { + return 0, fmt.Errorf("polling interval must be greater than 0") + } + + switch unit { + case "s": + if number < 60 && !IsDevelopmentMode() && !allowLessThanOneSecond { + return 0, fmt.Errorf("polling interval must be at least 60 seconds") + } + return time.Duration(number) * time.Second, nil + case "ms": + if number < 1000 && !IsDevelopmentMode() && !allowLessThanOneSecond { + return 0, fmt.Errorf("polling interval must be at least 1000 milliseconds") + } + return time.Duration(number) * time.Millisecond, nil + case "m": + return time.Duration(number) * time.Minute, nil + case "h": + return time.Duration(number) * time.Hour, nil + case "d": + return time.Duration(number) * 24 * time.Hour, nil + case "w": + return time.Duration(number) * 7 * 24 * time.Hour, nil + default: + return 0, fmt.Errorf("invalid time unit") + } +}