diff --git a/go.mod b/go.mod index 8996fabb..e7d7099e 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( github.com/quic-go/quic-go v0.54.1 github.com/rs/cors v1.11.0 github.com/rs/zerolog v1.26.1 + github.com/smallnest/resp3 v0.0.0-20251228151914-4f2fa7427e69 github.com/spf13/cobra v1.6.1 github.com/spf13/viper v1.8.1 github.com/stretchr/testify v1.10.0 @@ -73,7 +74,7 @@ require ( github.com/aws/smithy-go v1.20.2 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/cespare/xxhash v1.1.0 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chzyer/readline v1.5.1 // indirect github.com/danieljoos/wincred v1.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect @@ -81,6 +82,7 @@ require ( github.com/dustin/go-humanize v1.0.0 // indirect github.com/dvsekhvalnov/jose2go v1.7.0 // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect + github.com/emirpasic/gods v1.12.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect diff --git a/go.sum b/go.sum index e9f236ba..cb61e01f 100644 --- a/go.sum +++ b/go.sum @@ -111,8 +111,8 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charmbracelet/lipgloss v0.9.1 h1:PNyd3jvaJbg4jRHKWXnCj1akQm4rh8dbEzN1p/u1KWg= github.com/charmbracelet/lipgloss v0.9.1/go.mod h1:1mPmG4cxScwUQALAAnacHaigiiHB9Pmr+v1VEawJl6I= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= @@ -159,6 +159,8 @@ github.com/dvsekhvalnov/jose2go v1.7.0 h1:bnQc8+GMnidJZA8zc6lLEAb4xNrIqHwO+9Tzqv github.com/dvsekhvalnov/jose2go v1.7.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emirpasic/gods v1.12.0 h1:QAUIPSaCu4G+POclxeqb3F+WPpdKqFGlw36+yOzGlrg= +github.com/emirpasic/gods v1.12.0/go.mod h1:YfzfFFoVP/catgzJb4IKIqXjX78Ha8FMSDh3ymbK86o= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -513,6 +515,8 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/smallnest/resp3 v0.0.0-20251228151914-4f2fa7427e69 h1:AkDv2coi+ZsMlEp/6V21FWxdswSIEzqflgJ6snIQG+U= +github.com/smallnest/resp3 v0.0.0-20251228151914-4f2fa7427e69/go.mod h1:cmfXTZVXEA7xFOYcGnpKp2VeFf6FUHmxdKQHVNE6BXY= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= diff --git a/packages/cmd/pam.go b/packages/cmd/pam.go index 34913399..62727e4d 100644 --- a/packages/cmd/pam.go +++ b/packages/cmd/pam.go @@ -164,7 +164,7 @@ var pamKubernetesAccessAccountCmd = &cobra.Command{ Use: "access-account ", Short: "Access Kubernetes PAM account", Long: "Access Kubernetes via a PAM-managed Kubernetes account. This command automatically launches a proxy connected to your Kubernetes cluster through the Infisical Gateway.", - Example: "infisical pam kubernetes access-account prod/ssh/my-k8s-account --duration 2h", + Example: "infisical pam kubernetes access-account prod/ssh/my-k8s-account --duration 2h --project-id ", DisableFlagsInUseLine: true, Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { @@ -222,6 +222,76 @@ var pamKubernetesAccessAccountCmd = &cobra.Command{ }, } +var pamRedisCmd = &cobra.Command{ + Use: "redis", + Short: "Redis-related PAM commands", + Long: "Redis-related PAM commands for Infisical", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, +} + +var pamRedisAccessAccountCmd = &cobra.Command{ + Use: "access-account ", + Short: "Access Redis PAM account", + Long: "Access Redis via a PAM-managed Redis account. This starts a local Redis proxy server that you can use to connect to Redis directly.", + Example: "infisical pam redis access-account prod/redis/my-redis-account --duration 4h --port 6379 --project-id ", + DisableFlagsInUseLine: true, + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + util.RequireLogin() + + accountPath := args[0] + + projectID, err := cmd.Flags().GetString("project-id") + if err != nil { + util.HandleError(err, "Unable to parse project-id flag") + } + + if projectID == "" { + workspaceFile, err := util.GetWorkSpaceFromFile() + if err != nil { + util.PrintErrorMessageAndExit("Please either run infisical init to connect to a project or pass in project id with --project-id flag") + } + projectID = workspaceFile.WorkspaceId + } + + durationStr, err := cmd.Flags().GetString("duration") + if err != nil { + util.HandleError(err, "Unable to parse duration flag") + } + + // Parse duration + _, err = time.ParseDuration(durationStr) + if err != nil { + util.HandleError(err, "Invalid duration format. Use formats like '1h', '30m', '2h30m'") + } + + port, err := cmd.Flags().GetInt("port") + if err != nil { + util.HandleError(err, "Unable to parse port flag") + } + + log.Debug().Msg("PAM Redis Access: Trying to fetch secrets using logged in details") + + loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) + isConnected := util.ValidateInfisicalAPIConnection() + + if isConnected { + log.Debug().Msg("PAM Redis Access: Connected to Infisical instance, checking logged in creds") + } + + if err != nil { + util.HandleError(err, "Unable to get logged in user details") + } + + if isConnected && loggedInUserDetails.LoginExpired { + loggedInUserDetails = util.EstablishUserLoginSession() + } + + pam.StartRedisLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, accountPath, projectID, durationStr, port) + }, +} + func init() { pamDbCmd.AddCommand(pamDbAccessAccountCmd) pamDbAccessAccountCmd.Flags().String("duration", "1h", "Duration for database access session (e.g., '1h', '30m', '2h30m')") @@ -237,8 +307,14 @@ func init() { pamKubernetesAccessAccountCmd.Flags().Int("port", 0, "Port for the local kubernetes proxy server (0 for auto-assign)") pamKubernetesAccessAccountCmd.Flags().String("project-id", "", "Project ID of the account to access") + pamRedisCmd.AddCommand(pamRedisAccessAccountCmd) + pamRedisAccessAccountCmd.Flags().String("duration", "1h", "Duration for Redis access session (e.g., '1h', '30m', '2h30m')") + pamRedisAccessAccountCmd.Flags().Int("port", 0, "Port for the local Redis proxy server (0 for auto-assign)") + pamRedisAccessAccountCmd.Flags().String("project-id", "", "Project ID of the account to access") + pamCmd.AddCommand(pamDbCmd) pamCmd.AddCommand(pamSshCmd) pamCmd.AddCommand(pamKubernetesCmd) + pamCmd.AddCommand(pamRedisCmd) rootCmd.AddCommand(pamCmd) } diff --git a/packages/pam/handlers/mysql/relay_handler.go b/packages/pam/handlers/mysql/relay_handler.go index c3d2affd..e48b0335 100644 --- a/packages/pam/handlers/mysql/relay_handler.go +++ b/packages/pam/handlers/mysql/relay_handler.go @@ -105,13 +105,13 @@ func (r *RelayHandler) checkConnLostError(err error) { } } -func (r *RelayHandler) writeLogEntry(entry session.SessionLogEntry) (*mysql.Result, error) { +func (r *RelayHandler) writeLogEntry(entry session.SessionLogEntry) error { err := r.sessionLogger.LogEntry(entry) if err != nil { log.Error().Err(err).Msg("failed to write log entry to file") - return nil, err + return err } - return nil, nil + return nil } func formatResult(result *mysql.Result) string { diff --git a/packages/pam/handlers/redis/conn.go b/packages/pam/handlers/redis/conn.go new file mode 100644 index 00000000..02f6fb16 --- /dev/null +++ b/packages/pam/handlers/redis/conn.go @@ -0,0 +1,48 @@ +package redis + +import ( + "net" + + "github.com/smallnest/resp3" +) + +type RedisConn struct { + conn net.Conn + reader *resp3.Reader + writer *resp3.Writer +} + +func NewRedisConn(conn net.Conn) *RedisConn { + return &RedisConn{ + conn: conn, + reader: resp3.NewReader(conn), + writer: resp3.NewWriter(conn), + } +} + +func (c *RedisConn) Close() error { + defer func() { _ = c.conn.Close() }() + if err := c.writer.Flush(); err != nil { + return err + } + return nil +} + +func (c *RedisConn) Reader() *resp3.Reader { + return c.reader +} + +func (c *RedisConn) Writer() *resp3.Writer { + return c.writer +} + +func (c *RedisConn) WriteValue(value *resp3.Value, flush bool) error { + _, err := c.writer.WriteString(value.ToRESP3String()) + if err != nil { + return err + } + if !flush { + return nil + } + return c.writer.Flush() +} diff --git a/packages/pam/handlers/redis/proxy.go b/packages/pam/handlers/redis/proxy.go new file mode 100644 index 00000000..6a6c8a4d --- /dev/null +++ b/packages/pam/handlers/redis/proxy.go @@ -0,0 +1,99 @@ +package redis + +import ( + "context" + "crypto/tls" + "fmt" + "net" + + "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/rs/zerolog/log" + "github.com/smallnest/resp3" +) + +// RedisProxyConfig holds configuration for the Redis proxy +type RedisProxyConfig struct { + TargetAddr string + InjectUsername string + InjectPassword string + EnableTLS bool + TLSConfig *tls.Config + SessionID string + SessionLogger session.SessionLogger +} + +// RedisProxy handles proxying Redis connections +type RedisProxy struct { + config RedisProxyConfig + relayHandler *RelayHandler +} + +// NewRedisProxy creates a new Redis proxy instance +func NewRedisProxy(config RedisProxyConfig) *RedisProxy { + return &RedisProxy{config: config} +} + +// HandleConnection handles a single client connection +func (p *RedisProxy) HandleConnection(ctx context.Context, clientConn net.Conn) error { + defer clientConn.Close() + + sessionID := p.config.SessionID + + // Ensure session logger cleanup + defer func() { + if err := p.config.SessionLogger.Close(); err != nil { + log.Error().Err(err).Str("sessionID", sessionID).Msg("Failed to close session logger") + } + }() + + log.Info(). + Str("sessionID", sessionID). + Msg("New Redis connection for PAM session") + + var selfToServerConn net.Conn + if !p.config.EnableTLS { + c, err := net.Dial("tcp", p.config.TargetAddr) + if err != nil { + return err + } + selfToServerConn = c + } else { + c, err := tls.Dial("tcp", p.config.TargetAddr, p.config.TLSConfig) + if err != nil { + return err + } + selfToServerConn = c + } + + selfToClientRedisConn := NewRedisConn(selfToServerConn) + defer func(selfToClientRedisConn *RedisConn) { _ = selfToClientRedisConn.Close() }(selfToClientRedisConn) + + // Only authenticate if credentials are provided + if p.config.InjectUsername != "" && p.config.InjectPassword != "" { + if err := selfToClientRedisConn.Writer().WriteCommand("AUTH", p.config.InjectUsername, p.config.InjectPassword); err != nil { + return err + } + if err := selfToClientRedisConn.Writer().Flush(); err != nil { + return err + } + + respValue, _, err := selfToClientRedisConn.Reader().ReadValue() + if err != nil { + return err + } + if respValue.Str != "OK" { + errorMsg := "unknown" + if respValue.Type == resp3.TypeSimpleError || respValue.Type == resp3.TypeBlobError { + errorMsg = respValue.Err + } + log.Error().Str("errorMsg", errorMsg).Msg("Failed to authenticate with the target redis server") + return fmt.Errorf("failed to authenticate with the target redis server") + } + } + + clientToSelfConn := NewRedisConn(clientConn) + defer clientToSelfConn.Close() + + p.relayHandler = NewRelayHandler(clientToSelfConn, selfToClientRedisConn, p.config.SessionLogger) + return p.relayHandler.Handle(ctx) +} diff --git a/packages/pam/handlers/redis/relay_handler.go b/packages/pam/handlers/redis/relay_handler.go new file mode 100644 index 00000000..24d2d19c --- /dev/null +++ b/packages/pam/handlers/redis/relay_handler.go @@ -0,0 +1,241 @@ +package redis + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "strings" + "sync/atomic" + "time" + + "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/rs/zerolog/log" + "github.com/smallnest/resp3" +) + +// RelayHandler handles relaying commands and responses between client and server +type RelayHandler struct { + clientToSelfConn *RedisConn + selfToServerConn *RedisConn + sessionLogger session.SessionLogger +} + +type LogType string + +const ( + LogTypeCmd LogType = "cmd" + LogTypePush LogType = "push" + LogTypeMonitor LogType = "monitor" +) + +type RedisLogEntry struct { + LogType LogType `json:"type"` + Cmd interface{} `json:"cmd,omitempty"` +} + +type serverReply struct { + value *resp3.Value + err error +} + +// NewRelayHandler creates a new relay handler +func NewRelayHandler(clientToSelfConn *RedisConn, selfToServerConn *RedisConn, sessionLogger session.SessionLogger) *RelayHandler { + return &RelayHandler{ + clientToSelfConn: clientToSelfConn, + selfToServerConn: selfToServerConn, + sessionLogger: sessionLogger, + } +} + +func (h *RelayHandler) Handle(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + monitorMode := atomic.Bool{} + serverReplyCh := make(chan serverReply, 1) + + go func(ch chan<- serverReply) { + for { + if err := ctx.Err(); err != nil { + return + } + if err := h.selfToServerConn.conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Error().Err(err).Msg("failed to set read deadline") + ch <- serverReply{nil, err} + return + } + v, _, err := h.selfToServerConn.Reader().ReadValue() + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + continue + } + log.Error().Err(err).Msg("Error reading from server") + ch <- serverReply{nil, err} + return + } + + if monitorMode.Load() { + // In the monitoring mode, the server will keep sending simple string log values like + // + // +1766623914.688074 [0 127.0.0.1:33956] "AUTH" "(redacted)" + // + // We need to forward them all to the client. Other than that, we treat them all as + // server reply + if v.Type == resp3.TypeSimpleString && v.Str != "OK" { + err = h.clientToSelfConn.WriteValue(v, true) + if err != nil { + log.Error().Err(err).Msg("Error forwarding monitoring logs to the client") + ch <- serverReply{nil, err} + return + } + h.writeLogEntry(LogTypeMonitor, nil, v) + continue + } + } else if (v.Type == resp3.TypeArray && len(v.Elems) > 0 && strings.ToLower(v.Elems[0].Str) == "message") || + (v.Type == resp3.TypePush) { + // pubsub in resp2/resp3 mode will send a push as the confirmation instead of return anything, + // we need to treat that as a cmd reply otherwise the main loop will wait forever for the + // server reply to forward + if !isPubSubConfirmation(v) { + err = h.clientToSelfConn.WriteValue(v, true) + if err != nil { + log.Error().Err(err).Msg("Error forwarding push messages to the client") + ch <- serverReply{nil, err} + return + } + h.writeLogEntry(LogTypePush, nil, v) + continue + } + } + select { + case ch <- serverReply{v, nil}: + case <-ctx.Done(): + return + } + } + }(serverReplyCh) + + for { + if ctx.Err() != nil { + return ctx.Err() + } + err := h.clientToSelfConn.conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + return err + } + value, _, err := h.clientToSelfConn.Reader().ReadValue() + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + continue + } + return err + } + switch value.Type { + case resp3.TypeArray: + cmd := value.Elems[0] + if cmd.Type != resp3.TypeBlobString { + return fmt.Errorf("expected SimpleString, got %s", cmd.Type) + } + cmdStr := strings.ToLower(value.Elems[0].Str) + switch cmdStr { + // Handle auth command, we just reply OK instead of forwarding it to the server + case "auth": + err := h.clientToSelfConn.WriteValue(&resp3.Value{Type: resp3.TypeSimpleString, Str: "OK"}, true) + if err != nil { + return err + } + break + // Forward all other commands + default: + err := h.selfToServerConn.WriteValue(value, true) + if err != nil { + h.writeLogEntry(LogTypeCmd, value, nil) + return err + } + + if cmdStr == "monitor" { + // We need to turn on the monitor flag before we read the reply, + // otherwise some log msg might be treated as server reply + monitorMode.Store(true) + } + + reply := <-serverReplyCh + if reply.err != nil { + return reply.err + } + h.writeLogEntry(LogTypeCmd, value, reply.value) + err = h.clientToSelfConn.WriteValue(reply.value, true) + if err != nil { + return err + } + + if cmdStr == "monitor" && reply.value.Str != "OK" { + // looks like monitor cmd failed, let's revert the monitoring flag back + monitorMode.Store(false) + } + if cmdStr == "reset" && reply.value.Str == "OK" { + // the connection is reset, and we have exited monitor mode as well + monitorMode.Store(false) + // TODO: with reset cmd, should we send out AUTH again automatically to the server? + } + } + default: + if err = h.clientToSelfConn.WriteValue(&resp3.Value{Type: resp3.TypeSimpleError, Err: fmt.Sprintf("Unexpected value type %v", value.Type)}, true); err != nil { + return err + } + return fmt.Errorf("unexpected value type %v", value.Type) + } + } +} + +func (r *RelayHandler) writeLogEntry(logType LogType, cmd *resp3.Value, resp *resp3.Value) { + entry := RedisLogEntry{ + LogType: logType, + } + if logType == LogTypeCmd { + entry.Cmd = cmd.SmartResult() + } + input, err := valueToJson(entry) + if err != nil { + log.Error().Err(err).Msg("failed to convert cmd value to json") + return + } + output := "" + if resp != nil { + output, err = valueToJson(resp.SmartResult()) + if err != nil { + log.Error().Err(err).Msg("failed to convert resp value to json") + return + } + } + + err = r.sessionLogger.LogEntry(session.SessionLogEntry{ + Timestamp: time.Now(), + Input: input, + Output: output, + }) + if err != nil { + log.Error().Err(err).Msg("failed to write log entry to file") + } +} + +func valueToJson(value interface{}) (string, error) { + data, err := json.Marshal(value) + if err != nil { + return "", err + } + return string(data), nil +} + +func isPubSubConfirmation(value *resp3.Value) bool { + if len(value.Elems) < 1 { + return false + } + switch strings.ToLower(value.Elems[0].Str) { + case "subscribe", "psubscribe", "ssubscribe", "unsubscribe", "punsubscribe", "sunsubscribe": + return true + } + return false +} diff --git a/packages/pam/local/redis-proxy.go b/packages/pam/local/redis-proxy.go new file mode 100644 index 00000000..10b06d35 --- /dev/null +++ b/packages/pam/local/redis-proxy.go @@ -0,0 +1,347 @@ +package pam + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/config" + "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/go-resty/resty/v2" + "github.com/manifoldco/promptui" + "github.com/rs/zerolog/log" +) + +type RedisProxyServer struct { + BaseProxyServer // Embed common functionality + server net.Listener + port int +} + +func StartRedisLocalProxy(accessToken string, accountPath string, projectID string, durationStr string, port int) { + log.Info().Msgf("Starting Redis proxy for account: %s", accountPath) + log.Info().Msgf("Session duration: %s", durationStr) + + httpClient := resty.New() + httpClient.SetAuthToken(accessToken) + httpClient.SetHeader("User-Agent", "infisical-cli") + + pamRequest := api.PAMAccessRequest{ + Duration: durationStr, + AccountPath: accountPath, + ProjectId: projectID, + } + + pamResponse, err := api.CallPAMAccess(httpClient, pamRequest) + if err != nil { + var apiErr *api.APIError + if errors.As(err, &apiErr) && apiErr.ErrorMessage == "A policy is in place for this resource" { + if v, ok := apiErr.Details.(map[string]any); ok { + log.Info().Msgf("Account is protected by approval policy: %s", v["policyName"]) + + shouldSendRequest, err := askForApprovalRequestTrigger() + if err != nil { + if errors.Is(err, promptui.ErrAbort) { + log.Info().Msgf("Approval request was not created.") + } else { + util.HandleError(err, "Failed to send PAM account request") + } + return + } + + if !shouldSendRequest { + log.Info().Msgf("Approval request was not created.") + return + } + + approvalReq, err := api.CallPAMAccessApprovalRequest(httpClient, api.PAMAccessApprovalRequest{ + ProjectId: projectID, + RequestData: api.PAMAccessApprovalRequestPayloadRequestData{ + AccountPath: accountPath, + AccessDuration: durationStr, + }, + }) + if err != nil { + util.HandleError(err, "Failed to send PAM account request") + return + } + + url := fmt.Sprintf("%s/organizations/%s/projects/pam/%s/approval-requests/%s", strings.TrimSuffix(config.INFISICAL_URL, "/api"), approvalReq.Request.OrgId, approvalReq.Request.ProjectId, approvalReq.Request.ID) + if err := util.OpenBrowser(url); err != nil { + log.Error().Msgf("Failed to do browser redirect: %v", err) + } + log.Info().Msgf("Approval request created.") + log.Info().Msgf("View details at: %s", url) + return + } + } + + util.HandleError(err, "Failed to access PAM account") + return + } + + // Verify this is a Redis resource + if pamResponse.ResourceType != session.ResourceTypeRedis { + util.HandleError(fmt.Errorf("account is not a Redis resource, got: %s", pamResponse.ResourceType), "Invalid resource type") + return + } + + log.Info().Msgf("Redis session created with ID: %s", pamResponse.SessionId) + + duration, err := time.ParseDuration(durationStr) + if err != nil { + util.HandleError(err, "Failed to parse duration") + return + } + + ctx, cancel := context.WithCancel(context.Background()) + + proxy := &RedisProxyServer{ + BaseProxyServer: BaseProxyServer{ + httpClient: httpClient, + relayHost: pamResponse.RelayHost, + relayClientCert: pamResponse.RelayClientCertificate, + relayClientKey: pamResponse.RelayClientPrivateKey, + relayServerCertChain: pamResponse.RelayServerCertificateChain, + gatewayClientCert: pamResponse.GatewayClientCertificate, + gatewayClientKey: pamResponse.GatewayClientPrivateKey, + gatewayServerCertChain: pamResponse.GatewayServerCertificateChain, + sessionExpiry: time.Now().Add(duration), + sessionId: pamResponse.SessionId, + resourceType: pamResponse.ResourceType, + ctx: ctx, + cancel: cancel, + shutdownCh: make(chan struct{}), + }, + } + + if err := proxy.ValidateResourceTypeSupported(); err != nil { + util.HandleError(err, "Gateway version outdated") + return + } + + err = proxy.Start(port) + if err != nil { + util.HandleError(err, "Failed to start proxy server") + return + } + + if port == 0 { + fmt.Printf("Redis proxy started for account %s with duration %s on port %d (auto-assigned)\n", accountPath, duration.String(), proxy.port) + } else { + fmt.Printf("Redis proxy started for account %s with duration %s on port %d\n", accountPath, duration.String(), proxy.port) + } + + username, ok := pamResponse.Metadata["username"] + if !ok { + username = "" // Redis may not always have username + } + accountName, ok := pamResponse.Metadata["accountName"] + if !ok { + util.HandleError(fmt.Errorf("PAM response metadata is missing 'accountName'"), "Failed to start proxy server") + return + } + accountPathMetadata, ok := pamResponse.Metadata["accountPath"] + if !ok { + util.HandleError(fmt.Errorf("PAM response metadata is missing 'accountPath'"), "Failed to start proxy server") + return + } + + log.Info().Msgf("Redis proxy server listening on port %d", proxy.port) + fmt.Printf("\n") + fmt.Printf("**********************************************************************\n") + fmt.Printf(" Redis Proxy Session Started! \n") + fmt.Printf("----------------------------------------------------------------------\n") + fmt.Printf("Accessing account %s at folder path %s\n", accountName, accountPathMetadata) + fmt.Printf("\n") + fmt.Printf("You can now connect to your Redis instance using:\n") + if username != "" { + fmt.Printf("redis://%s@localhost:%d", username, proxy.port) + } else { + fmt.Printf("redis://localhost:%d", proxy.port) + } + fmt.Printf("\n**********************************************************************\n") + fmt.Printf("\n") + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-sigChan + log.Info().Msgf("Received signal %v, initiating graceful shutdown...", sig) + proxy.gracefulShutdown() + }() + + proxy.Run() +} + +func (p *RedisProxyServer) Start(port int) error { + var err error + if port == 0 { + p.server, err = net.Listen("tcp", ":0") + } else { + p.server, err = net.Listen("tcp", fmt.Sprintf(":%d", port)) + } + + if err != nil { + return fmt.Errorf("failed to start server: %w", err) + } + + addr := p.server.Addr().(*net.TCPAddr) + p.port = addr.Port + + return nil +} + +func (p *RedisProxyServer) gracefulShutdown() { + p.shutdownOnce.Do(func() { + log.Info().Msg("Starting graceful shutdown of Redis proxy...") + + // Send session termination notification before cancelling context + p.NotifySessionTermination() + + // Signal the accept loop to stop + close(p.shutdownCh) + + // Close the server to stop accepting new connections + if p.server != nil { + p.server.Close() + } + + // Cancel context to signal all goroutines to stop + p.cancel() + + // Wait for connections to close + p.WaitForConnectionsWithTimeout(10 * time.Second) + + log.Info().Msg("Redis proxy shutdown complete") + os.Exit(0) + }) +} + +func (p *RedisProxyServer) Run() { + defer p.server.Close() + + for { + select { + case <-p.ctx.Done(): + log.Info().Msg("Context cancelled, stopping proxy server") + return + case <-p.shutdownCh: + log.Info().Msg("Shutdown signal received, stopping proxy server") + return + default: + // Check if session has expired + if time.Now().After(p.sessionExpiry) { + log.Warn().Msg("Redis session expired, shutting down proxy") + p.gracefulShutdown() + return + } + + if tcpListener, ok := p.server.(*net.TCPListener); ok { + tcpListener.SetDeadline(time.Now().Add(1 * time.Second)) + } + + conn, err := p.server.Accept() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + select { + case <-p.ctx.Done(): + return + case <-p.shutdownCh: + return + default: + log.Error().Err(err).Msg("Failed to accept connection") + continue + } + } + + // Track active connection + p.activeConnections.Add(1) + go p.handleConnection(conn) + } + } +} + +func (p *RedisProxyServer) handleConnection(clientConn net.Conn) { + defer func() { + clientConn.Close() + p.activeConnections.Done() + }() + + log.Info().Msgf("New connection from %s", clientConn.RemoteAddr()) + + select { + case <-p.ctx.Done(): + log.Info().Msg("Context cancelled, closing connection immediately") + return + default: + } + + relayConn, err := p.CreateRelayConnection() + if err != nil { + log.Error().Err(err).Msg("Failed to connect to relay") + return + } + defer relayConn.Close() + + gatewayConn, err := p.CreateGatewayConnection(relayConn, ALPNInfisicalPAMProxy) + if err != nil { + log.Error().Err(err).Msg("Failed to connect to gateway") + return + } + defer gatewayConn.Close() + + log.Info().Msg("Established connection to Redis resource") + + connCtx, connCancel := context.WithCancel(p.ctx) + defer connCancel() + + errCh := make(chan error, 2) + + // Bidirectional data forwarding with context cancellation + go func() { + defer connCancel() + _, err := io.Copy(clientConn, gatewayConn) + if err != nil { + select { + case <-connCtx.Done(): + default: + log.Debug().Err(err).Msg("Gateway to client copy ended") + } + } + errCh <- err + }() + + go func() { + defer connCancel() + _, err := io.Copy(gatewayConn, clientConn) + if err != nil { + select { + case <-connCtx.Done(): + default: + log.Debug().Err(err).Msg("Client to gateway copy ended") + } + } + errCh <- err + }() + + select { + case <-errCh: + case <-connCtx.Done(): + log.Info().Msg("Connection cancelled by context") + } + + log.Info().Msgf("Connection closed for client: %s", clientConn.RemoteAddr().String()) +} diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index ae272ad7..8e464901 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -12,6 +12,7 @@ import ( "github.com/Infisical/infisical-merge/packages/pam/handlers" "github.com/Infisical/infisical-merge/packages/pam/handlers/kubernetes" "github.com/Infisical/infisical-merge/packages/pam/handlers/mysql" + "github.com/Infisical/infisical-merge/packages/pam/handlers/redis" "github.com/Infisical/infisical-merge/packages/pam/handlers/ssh" "github.com/Infisical/infisical-merge/packages/pam/session" "github.com/go-resty/resty/v2" @@ -37,6 +38,7 @@ func GetSupportedResourceTypes() []string { session.ResourceTypeMysql, session.ResourceTypeSSH, session.ResourceTypeKubernetes, + session.ResourceTypeRedis, } } @@ -207,6 +209,24 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting MySQL PAM proxy") return proxy.HandleConnection(ctx, conn) + case session.ResourceTypeRedis: + redisConfig := redis.RedisProxyConfig{ + TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + EnableTLS: credentials.SSLEnabled, + TLSConfig: tlsConfig, + SessionID: pamConfig.SessionId, + SessionLogger: sessionLogger, + } + + proxy := redis.NewRedisProxy(redisConfig) + log.Info(). + Str("sessionId", pamConfig.SessionId). + Str("target", redisConfig.TargetAddr). + Bool("sslEnabled", credentials.SSLEnabled). + Msg("Starting Redis PAM proxy") + return proxy.HandleConnection(ctx, conn) case session.ResourceTypeSSH: sshConfig := ssh.SSHProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), diff --git a/packages/pam/session/uploader.go b/packages/pam/session/uploader.go index 100da783..87139a53 100644 --- a/packages/pam/session/uploader.go +++ b/packages/pam/session/uploader.go @@ -24,6 +24,7 @@ var ErrSessionFileNotFound = errors.New("session file not found") const ( ResourceTypePostgres = "postgres" ResourceTypeMysql = "mysql" + ResourceTypeRedis = "redis" ResourceTypeSSH = "ssh" ResourceTypeKubernetes = "kubernetes" ) @@ -54,7 +55,7 @@ func NewSessionUploader(httpClient *resty.Client, credentialsManager *Credential func ParseSessionFilename(filename string) (*SessionFileInfo, error) { // Try new format first: pam_session_{sessionID}_{resourceType}_expires_{timestamp}.enc // Build regex pattern using constants - resourceTypePattern := fmt.Sprintf("(%s|%s|%s|%s)", ResourceTypeSSH, ResourceTypePostgres, ResourceTypeMysql, ResourceTypeKubernetes) + resourceTypePattern := fmt.Sprintf("(%s|%s|%s|%s|%s)", ResourceTypeSSH, ResourceTypePostgres, ResourceTypeRedis, ResourceTypeMysql, ResourceTypeKubernetes) newFormatRegex := regexp.MustCompile(fmt.Sprintf(`^pam_session_(.+)_%s_expires_(\d+)\.enc$`, resourceTypePattern)) matches := newFormatRegex.FindStringSubmatch(filename)