diff --git a/go.mod b/go.mod index 015be2a1..60de16d9 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,13 @@ toolchain go1.23.5 require ( github.com/BobuSumisu/aho-corasick v1.0.3 + github.com/Infisical/sql-query-identifier v0.0.0-20251118234314-f02b443269f4 github.com/Masterminds/sprig/v3 v3.3.0 github.com/bradleyjkemp/cupaloy/v2 v2.8.0 github.com/charmbracelet/lipgloss v0.9.1 github.com/creack/pty v1.1.21 github.com/denisbrodbeck/machineid v1.0.1 + github.com/dgraph-io/badger/v3 v3.2103.5 github.com/fatih/semgroup v1.2.0 github.com/gitleaks/go-gitdiff v0.9.1 github.com/go-mysql-org/go-mysql v1.13.0 @@ -77,7 +79,6 @@ require ( github.com/chzyer/readline v1.5.1 // indirect github.com/danieljoos/wincred v1.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/dgraph-io/badger/v3 v3.2103.5 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/dvsekhvalnov/jose2go v1.6.0 // indirect diff --git a/go.sum b/go.sum index c3f0e0fb..15f8ca47 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Infisical/go-keyring v1.0.2 h1:dWOkI/pB/7RocfSJgGXbXxLDcVYsdslgjEPmVhb+nl8= github.com/Infisical/go-keyring v1.0.2/go.mod h1:LWOnn/sw9FxDW/0VY+jHFAfOFEe03xmwBVSfJnBowto= +github.com/Infisical/sql-query-identifier v0.0.0-20251118234314-f02b443269f4 h1:xU/V9xG03uXBx3Ibsggh3xnavIJp4ZEKs0bDzBu9zHE= +github.com/Infisical/sql-query-identifier v0.0.0-20251118234314-f02b443269f4/go.mod h1:okzj7syePKK5CZjvWF6POKfpuzmQYgjKf+f+MliRFUs= github.com/Infisical/turn/v4 v4.0.1 h1:omdelNsnFfzS5cu86W5OBR68by68a8sva4ogR0lQQnw= github.com/Infisical/turn/v4 v4.0.1/go.mod h1:pMMKP/ieNAG/fN5cZiN4SDuyKsXtNTr0ccN7IToA1zs= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= @@ -63,6 +65,7 @@ github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+ github.com/Masterminds/semver/v3 v3.3.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs= github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0= +github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/alessio/shellescape v1.4.1 h1:V7yhSDDn8LP4lc4jS8pFkt0zCnzVJlG5JXy9BVKJUX0= github.com/alessio/shellescape v1.4.1/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= @@ -150,6 +153,7 @@ github.com/dgraph-io/badger/v3 v3.2103.5 h1:ylPa6qzbjYRQMU6jokoj4wzcaweHylt//CH0 github.com/dgraph-io/badger/v3 v3.2103.5/go.mod h1:4MPiseMeDQ3FNCYwRbbcBOGJLf5jsE0PPFzRiKjtcdw= github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -516,6 +520,7 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9 github.com/sony/gobreaker v0.5.0 h1:dRCvqm0P490vZPmy7ppEk2qCnCieBooFJ+YoXGYB+yg= github.com/sony/gobreaker v0.5.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/afero v1.6.0 h1:xoax2sJ2DT8S8xA2paPFjDCScCNeWsg75VG0DLRreiY= diff --git a/packages/api/model.go b/packages/api/model.go index 5b78305a..08aa181e 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -802,6 +802,7 @@ type PAMSessionCredentials struct { SSLCertificate string `json:"sslCertificate,omitempty"` Username string `json:"username"` Password string `json:"password"` + ReadOnlyMode bool `json:"readOnlyMode"` AuthMethod string `json:"authMethod,omitempty"` PrivateKey string `json:"privateKey,omitempty"` } diff --git a/packages/pam/handlers/mysql/proxy.go b/packages/pam/handlers/mysql/proxy.go index f999e246..fa660442 100644 --- a/packages/pam/handlers/mysql/proxy.go +++ b/packages/pam/handlers/mysql/proxy.go @@ -4,12 +4,13 @@ import ( "context" "crypto/tls" "fmt" + "net" + "github.com/Infisical/infisical-merge/packages/pam/session" "github.com/go-mysql-org/go-mysql/client" "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/server" "github.com/rs/zerolog/log" - "net" ) // TODO: DRY with psql? @@ -22,6 +23,7 @@ type MysqlProxyConfig struct { TLSConfig *tls.Config SessionID string SessionLogger session.SessionLogger + ReadOnlyMode bool } type MysqlProxy struct { @@ -60,15 +62,14 @@ func (p *MysqlProxy) HandleConnection(ctx context.Context, clientConn net.Conn) defer selfServerConn.Close() actualServer := server.NewServer( - // Let's use a conservative version to let the client not to throw - // many too fancy stuff at us to get the V1 out of door fast + // smaller version to prevent complex errors "8.0.11", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_NATIVE_PASSWORD, nil, nil, ) - p.relayHandler = NewRelayHandler(selfServerConn, p.config.SessionLogger) + p.relayHandler = NewRelayHandler(selfServerConn, p.config.SessionLogger, p.config) clientSelfConn, err := actualServer.NewCustomizedConn( clientConn, &AnyUserCredentialProvider{}, @@ -86,6 +87,13 @@ func (p *MysqlProxy) HandleConnection(ctx context.Context, clientConn net.Conn) } }() + // if in read-only mode, set the session to be a read-only transaction + if p.config.ReadOnlyMode { + if err := p.setSessionReadOnly(selfServerConn); err != nil { + return err + } + } + for !clientSelfConn.Closed() && !p.relayHandler.Closed() { err = clientSelfConn.HandleCommand() if err != nil { @@ -117,3 +125,16 @@ func (p *MysqlProxy) connectToServer() (*client.Conn, error) { } return conn, nil } + +func (p *MysqlProxy) setSessionReadOnly(serverConn *client.Conn) error { + log.Info().Str("sessionID", p.config.SessionID).Msg("Setting session to read-only transaction mode") + + _, err := serverConn.Execute("SET SESSION TRANSACTION READ ONLY;") + if err != nil { + log.Error().Err(err).Str("sessionID", p.config.SessionID).Msg("Failed to set session to read-only mode") + return fmt.Errorf("failed to set session to read-only: %w", err) + } + + log.Debug().Str("sessionID", p.config.SessionID).Msg("Session set to read-only successfully") + return nil +} diff --git a/packages/pam/handlers/mysql/relay_handler.go b/packages/pam/handlers/mysql/relay_handler.go index c3d2affd..7512443b 100644 --- a/packages/pam/handlers/mysql/relay_handler.go +++ b/packages/pam/handlers/mysql/relay_handler.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Infisical/infisical-merge/packages/pam/session" + sqi "github.com/Infisical/sql-query-identifier" "github.com/go-mysql-org/go-mysql/client" "github.com/go-mysql-org/go-mysql/mysql" "github.com/pkg/errors" @@ -16,10 +17,11 @@ type RelayHandler struct { selfServerConn *client.Conn sessionLogger session.SessionLogger closed atomic.Bool + config MysqlProxyConfig } -func NewRelayHandler(selfServerConn *client.Conn, sessionLogger session.SessionLogger) *RelayHandler { - return &RelayHandler{selfServerConn, sessionLogger, atomic.Bool{}} +func NewRelayHandler(selfServerConn *client.Conn, sessionLogger session.SessionLogger, config MysqlProxyConfig) *RelayHandler { + return &RelayHandler{selfServerConn, sessionLogger, atomic.Bool{}, config} } func (r *RelayHandler) Closed() bool { @@ -33,6 +35,15 @@ func (r *RelayHandler) UseDB(dbName string) error { } func (r *RelayHandler) HandleQuery(query string) (*mysql.Result, error) { + if r.config.ReadOnlyMode { + isReadOnly, err := r.handleReadOnlyCheck(query) + if err != nil { + return nil, err + } + if !isReadOnly { + return nil, mysql.NewError(mysql.ER_OPTION_PREVENTS_STATEMENT, "Operation not allowed by policy in read-only mode.") + } + } result, err := r.selfServerConn.Execute(query) r.checkConnLostError(err) if err != nil { @@ -59,6 +70,15 @@ func (r *RelayHandler) HandleFieldList(table string, fieldWildcard string) ([]*m } func (r *RelayHandler) HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error) { + if r.config.ReadOnlyMode { + isReadOnly, err := r.handleReadOnlyCheck(query) + if err != nil { + return 0, 0, nil, err + } + if !isReadOnly { + return 0, 0, nil, mysql.NewError(mysql.ER_OPTION_PREVENTS_STATEMENT, "Operation not allowed by policy in read-only mode.") + } + } stmt, err := r.selfServerConn.Prepare(query) r.checkConnLostError(err) if err != nil { @@ -68,6 +88,15 @@ func (r *RelayHandler) HandleStmtPrepare(query string) (params int, columns int, } func (r *RelayHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) { + if r.config.ReadOnlyMode { + isReadOnly, err := r.handleReadOnlyCheck(query) + if err != nil { + return nil, err + } + if !isReadOnly { + return nil, mysql.NewError(mysql.ER_OPTION_PREVENTS_STATEMENT, "Operation not allowed by policy in read-only mode.") + } + } stmt := context.(*client.Stmt) result, err := stmt.Execute(args...) r.checkConnLostError(err) @@ -114,6 +143,43 @@ func (r *RelayHandler) writeLogEntry(entry session.SessionLogEntry) (*mysql.Resu return nil, nil } +func (r *RelayHandler) handleReadOnlyCheck(query string) (bool, error) { + if query == "" { + return true, nil + } + + dialect := sqi.DialectMySQL + strict := false + options := sqi.IdentifyOptions{ + Dialect: &dialect, + Strict: &strict, + } + + identifiedQueries, err := sqi.Identify(query, options) + if err != nil { + log.Error(). + Str("sessionID", r.config.SessionID). + Str("query", query). + Err(err). + Msg("Failed to identify query; blocking in read-only mode.") + return false, err + } + + // verify that every statement in the query is read-only + for _, identifiedQuery := range identifiedQueries { + if identifiedQuery.ExecutionType != sqi.ExecutionListing && identifiedQuery.ExecutionType != sqi.ExecutionInformation { + log.Warn(). + Str("sessionID", r.config.SessionID). + Str("query", query). + Str("executionType", string(identifiedQuery.ExecutionType)). + Msg("Write query blocked in read-only mode.") + return false, nil + } + } + + return true, nil +} + func formatResult(result *mysql.Result) string { if result.Resultset != nil { return fmt.Sprintf("SUCCESS (%d rows affected)", len(result.Resultset.Values)) diff --git a/packages/pam/handlers/postgres.go b/packages/pam/handlers/postgres.go index fbf0b820..23587042 100644 --- a/packages/pam/handlers/postgres.go +++ b/packages/pam/handlers/postgres.go @@ -19,6 +19,8 @@ import ( "github.com/jackc/pgx/v5/pgproto3" "github.com/rs/zerolog/log" "golang.org/x/crypto/pbkdf2" // TODO: Remove this once we update to go 1.25.1 or later where it's already in the standard library + + sqi "github.com/Infisical/sql-query-identifier" ) type PostgresProxyConfig struct { @@ -30,6 +32,7 @@ type PostgresProxyConfig struct { TLSConfig *tls.Config SessionID string SessionLogger session.SessionLogger + ReadOnlyMode bool } type PostgresProxy struct { @@ -97,6 +100,13 @@ func (p *PostgresProxy) HandleConnection(ctx context.Context, clientConn net.Con return fmt.Errorf("startup failed: %w", err) } + // if in read-only mode, set the session to be a read-only transaction + if p.config.ReadOnlyMode { + if err := p.setSessionReadOnly(clientBackend, serverFrontend); err != nil { + return err + } + } + // Proxy messages bidirectionally errChan := make(chan error, 2) @@ -668,6 +678,134 @@ func (p *PostgresProxy) handleMD5PasswordAsProxy(clientBackend *pgproto3.Backend } } +func (p *PostgresProxy) setSessionReadOnly(clientBackend *pgproto3.Backend, serverFrontend *pgproto3.Frontend) error { + log.Info().Str("sessionID", p.config.SessionID).Msg("Setting session to read-only transaction mode") + + readOnlyQuery := &pgproto3.Query{String: "SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY;"} + serverFrontend.Send(readOnlyQuery) + if err := serverFrontend.Flush(); err != nil { + log.Error().Err(err).Str("sessionID", p.config.SessionID).Msg("Failed to send read-only session command") + return fmt.Errorf("flushing read-only command: %w", err) + } + + var sawCommandComplete, sawReadyForQuery bool + + for { + msg, err := serverFrontend.Receive() + if err != nil { + log.Error().Err(err).Str("sessionID", p.config.SessionID).Msg("Error receiving response for read-only command") + return fmt.Errorf("receiving response for read-only command: %w", err) + } + + log.Debug(). + Str("sessionID", p.config.SessionID). + Str("msgType", fmt.Sprintf("%T", msg)). + Msg("← SERVER (read-only setup)") + + switch m := msg.(type) { + case *pgproto3.CommandComplete: + sawCommandComplete = true + case *pgproto3.ReadyForQuery: + sawReadyForQuery = true + case *pgproto3.ParameterStatus: + continue + case *pgproto3.BackendKeyData: + continue + case *pgproto3.ErrorResponse: + log.Error(). + Str("sessionID", p.config.SessionID). + Str("code", m.Code). + Str("message", m.Message). + Msg("Server returned an error when setting session to read-only mode") + + clientBackend.Send(m) + if err := clientBackend.Flush(); err != nil { + log.Warn().Err(err).Str("sessionID", p.config.SessionID).Msg("Failed to flush error response to client during read-only setup") + } + + return fmt.Errorf("server error on setting read-only mode: %s", m.Message) + default: + log.Warn(). + Str("sessionID", p.config.SessionID). + Str("msgType", fmt.Sprintf("%T", m)). + Msg("Received unexpected message during read-only setup") + + return fmt.Errorf("unexpected message type %T during read-only setup", m) + } + + if sawCommandComplete && sawReadyForQuery { + return nil + } + } +} + +func getQueryContentFromMessage(msg pgproto3.FrontendMessage) string { + switch m := msg.(type) { + case *pgproto3.Query: + return m.String + case *pgproto3.Parse: + return m.Query + default: + return "" + } +} + +func (p *PostgresProxy) handleReadOnlyCheck(msg pgproto3.FrontendMessage, clientBackend *pgproto3.Backend, errChan chan error) bool { + queryContent := getQueryContentFromMessage(msg) + if queryContent == "" { + return true + } + + dialect := sqi.DialectPSQL + strict := false + options := sqi.IdentifyOptions{ + Dialect: &dialect, + Strict: &strict, + } + + identifiedQueries, err := sqi.Identify(queryContent, options) + if err != nil { + log.Error(). + Str("sessionID", p.config.SessionID). + Str("query", queryContent). + Err(err). + Msg("Failed to identify query; blocking in read-only mode.") + + errorResponse := &pgproto3.ErrorResponse{ + Severity: "ERROR", + Code: "XX000", // internal_error + Message: "Failed to analyze query in read-only mode.", + } + clientBackend.Send(errorResponse) + _ = clientBackend.Flush() + errChan <- fmt.Errorf("failed to identify query, blocking: %s", queryContent) + return false + } + + // verify that every statement in the query is read-only + for _, identifiedQuery := range identifiedQueries { + if identifiedQuery.ExecutionType != sqi.ExecutionListing && identifiedQuery.ExecutionType != sqi.ExecutionInformation { + log.Warn(). + Str("sessionID", p.config.SessionID). + Str("query", queryContent). + Str("executionType", string(identifiedQuery.ExecutionType)). + Msg("Write query blocked in read-only mode.") + + errorResponse := &pgproto3.ErrorResponse{ + Severity: "ERROR", + Code: "42803", // insufficient_privilege + Message: "Operation not allowed by policy in read-only mode.", + } + clientBackend.Send(errorResponse) + _ = clientBackend.Flush() + errChan <- fmt.Errorf("write query blocked: %s (type: %s)", queryContent, identifiedQuery.ExecutionType) + return false + } + } + + return true +} + func (p *PostgresProxy) proxyClientToServer(clientBackend *pgproto3.Backend, serverFrontend *pgproto3.Frontend, errChan chan error) { for { msg, err := clientBackend.Receive() @@ -676,6 +814,10 @@ func (p *PostgresProxy) proxyClientToServer(clientBackend *pgproto3.Backend, ser return } + if p.config.ReadOnlyMode && !p.handleReadOnlyCheck(msg, clientBackend, errChan) { + return + } + p.trackClientMessage(msg) serverFrontend.Send(msg) diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index 78e55ca7..d79efe47 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -120,6 +120,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo TLSConfig: tlsConfig, SessionID: pamConfig.SessionId, SessionLogger: sessionLogger, + ReadOnlyMode: credentials.ReadOnlyMode, } proxy := handlers.NewPostgresProxy(proxyConfig) log.Info(). @@ -138,6 +139,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo TLSConfig: tlsConfig, SessionID: pamConfig.SessionId, SessionLogger: sessionLogger, + ReadOnlyMode: credentials.ReadOnlyMode, } proxy := mysql.NewMysqlProxy(mysqlConfig) diff --git a/packages/pam/session/credentials.go b/packages/pam/session/credentials.go index 3472bcf0..5f515535 100644 --- a/packages/pam/session/credentials.go +++ b/packages/pam/session/credentials.go @@ -21,6 +21,7 @@ type PAMCredentials struct { SSLEnabled bool SSLRejectUnauthorized bool SSLCertificate string + ReadOnlyMode bool } type cachedCredentials struct { @@ -92,6 +93,7 @@ func (cm *CredentialsManager) GetPAMSessionCredentials(sessionId string, expiryT SSLEnabled: response.Credentials.SSLEnabled, SSLRejectUnauthorized: response.Credentials.SSLRejectUnauthorized, SSLCertificate: response.Credentials.SSLCertificate, + ReadOnlyMode: response.Credentials.ReadOnlyMode, } cm.cacheMutex.Lock()