diff --git a/.env.example b/.env.example index 3dfa7e454..37822b1dc 100644 --- a/.env.example +++ b/.env.example @@ -79,6 +79,26 @@ REDIS_HOST=127.0.0.1 REDIS_PORT=6379 REDIS_PASSWORD=difyai123456 REDIS_DB=0 +REDIS_USE_SSL=false +# SSL configuration for Redis (when REDIS_USE_SSL=true) +REDIS_SSL_CERT_REQS=CERT_NONE +# REDIS_SSL_CERT_REQS controls how server certificates are verified: +# - CERT_NONE: Skips all certificate verification (insecure, sets InsecureSkipVerify=true) +# Use only in development/testing environments. This is the default in this example file. +# - CERT_OPTIONAL: Requires valid certificate verification (same as CERT_REQUIRED for client-side TLS) +# CERT_OPTIONAL is treated as CERT_REQUIRED because servers almost always present +# certificates, and the client's choice is whether to validate them or not +# Uses system's default CA certificates if REDIS_SSL_CA_CERTS is not provided +# - CERT_REQUIRED: Requires valid certificate verification (most secure, sets InsecureSkipVerify=false) +# Recommended for production environments +# IMPORTANT: REDIS_SSL_CA_CERTS must be provided, otherwise the application will fail to start +# - Empty string: Behaves like CERT_OPTIONAL (secure, enables verification, but allows system CA certificates) +# This is the default when REDIS_SSL_CERT_REQS is not set +REDIS_SSL_CA_CERTS= +# Path to the CA certificate file for SSL verification, e.g. /path/to/ca.crt +# REQUIRED when REDIS_SSL_CERT_REQS=CERT_REQUIRED +# Optional for CERT_OPTIONAL (uses system's default CA certificates if not provided) +# Ignored when REDIS_SSL_CERT_REQS=CERT_NONE # Whether to use Redis Sentinel mode. # If set to true, the application will automatically discover and connect to the master node through Sentinel. @@ -98,8 +118,8 @@ DB_PASSWORD=difyai123456 DB_HOST=localhost DB_PORT=5432 DB_DATABASE=dify_plugin -# Specifies the SSL mode for the database connection. -# Possible values include 'disable', 'require', 'verify-ca', and 'verify-full'. +# Specifies the SSL mode for the database connection. +# Possible values include 'disable', 'require', 'verify-ca', and 'verify-full'. # 'disable' means SSL is not used for the connection. DB_SSL_MODE=disable # database connection pool settings diff --git a/internal/cluster/clutser_test.go b/internal/cluster/clutser_test.go index 3980b6487..39295e80f 100644 --- a/internal/cluster/clutser_test.go +++ b/internal/cluster/clutser_test.go @@ -10,7 +10,7 @@ import ( ) func createSimulationCluster(nums int) ([]*Cluster, error) { - err := cache.InitRedisClient("0.0.0.0:6379", "", "difyai123456", false, 0) + err := cache.InitRedisClient("0.0.0.0:6379", "", "difyai123456", false, 0, nil) if err != nil { return nil, err } diff --git a/internal/core/debugging_runtime/server_test.go b/internal/core/debugging_runtime/server_test.go index 46d677a76..fff8f438a 100644 --- a/internal/core/debugging_runtime/server_test.go +++ b/internal/core/debugging_runtime/server_test.go @@ -113,7 +113,7 @@ func (n *TestPluginRuntimeNotifier) OnServerShutdown(reason ServerShutdownReason // TestAcceptConnection tests the acceptance of the connection func TestAcceptConnection(t *testing.T) { - if cache.InitRedisClient("0.0.0.0:6379", "", "difyai123456", false, 0) != nil { + if cache.InitRedisClient("0.0.0.0:6379", "", "difyai123456", false, 0, nil) != nil { t.Errorf("failed to init redis client") return } @@ -338,7 +338,7 @@ func TestNoHandleShakeIn10Seconds(t *testing.T) { } func TestIncorrectHandshake(t *testing.T) { - if cache.InitRedisClient("0.0.0.0:6379", "", "difyai123456", false, 0) != nil { + if cache.InitRedisClient("0.0.0.0:6379", "", "difyai123456", false, 0, nil) != nil { t.Errorf("failed to init redis client") return } diff --git a/internal/core/persistence/persistence_test.go b/internal/core/persistence/persistence_test.go index 9134c3136..c84791624 100644 --- a/internal/core/persistence/persistence_test.go +++ b/internal/core/persistence/persistence_test.go @@ -14,7 +14,7 @@ import ( ) func TestPersistenceStoreAndLoad(t *testing.T) { - err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0) + err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil) if err != nil { t.Fatalf("Failed to init redis client: %v", err) } @@ -70,7 +70,7 @@ func TestPersistenceStoreAndLoad(t *testing.T) { } func TestPersistenceSaveAndLoadWithLongKey(t *testing.T) { - err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0) + err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil) assert.Nil(t, err) defer cache.Close() db.Init(&app.Config{ @@ -104,7 +104,7 @@ func TestPersistenceSaveAndLoadWithLongKey(t *testing.T) { } func TestPersistenceDelete(t *testing.T) { - err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0) + err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil) assert.Nil(t, err) defer cache.Close() db.Init(&app.Config{ @@ -151,7 +151,7 @@ func TestPersistenceDelete(t *testing.T) { } func TestPersistencePathTraversal(t *testing.T) { - err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0) + err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil) if err != nil { t.Fatalf("Failed to init redis client: %v", err) } diff --git a/internal/core/plugin_manager/manager.go b/internal/core/plugin_manager/manager.go index 597439426..c6bde37ff 100644 --- a/internal/core/plugin_manager/manager.go +++ b/internal/core/plugin_manager/manager.go @@ -102,6 +102,12 @@ func (p *PluginManager) GetAsset(id string) ([]byte, error) { func (p *PluginManager) Launch(configuration *app.Config) { log.Info("start plugin manager daemon") + // Build TLS config for Redis (nil when RedisUseSsl=false) + tlsConf, err := configuration.RedisTLSConfig() + if err != nil { + log.Panic("invalid Redis TLS config: %s", err.Error()) + } + // init redis client if configuration.RedisUseSentinel { // use Redis Sentinel @@ -116,6 +122,7 @@ func (p *PluginManager) Launch(configuration *app.Config) { configuration.RedisUseSsl, configuration.RedisDB, configuration.RedisSentinelSocketTimeout, + tlsConf, // pass TLS to cache initializer ); err != nil { log.Panic("init redis sentinel client failed", "error", err) } @@ -126,6 +133,7 @@ func (p *PluginManager) Launch(configuration *app.Config) { configuration.RedisPass, configuration.RedisUseSsl, configuration.RedisDB, + tlsConf, // pass TLS to cache initializer ); err != nil { log.Panic("init redis client failed", "error", err) } diff --git a/internal/core/session_manager/session_trace_test.go b/internal/core/session_manager/session_trace_test.go index b282ce9bf..9d99b7fac 100644 --- a/internal/core/session_manager/session_trace_test.go +++ b/internal/core/session_manager/session_trace_test.go @@ -63,7 +63,7 @@ func TestGetSessionTraceInMemory(t *testing.T) { } func TestGetSessionTraceFromDistributedCache(t *testing.T) { - require.NoError(t, cache.InitRedisClient("127.0.0.1:6379", "", "difyai123456", false, 0)) + require.NoError(t, cache.InitRedisClient("127.0.0.1:6379", "", "difyai123456", false, 0, nil)) t.Cleanup(func() { cache.Close() }) diff --git a/internal/service/debugging_service/connection_key_test.go b/internal/service/debugging_service/connection_key_test.go index eefca3a13..8dbd6026c 100644 --- a/internal/service/debugging_service/connection_key_test.go +++ b/internal/service/debugging_service/connection_key_test.go @@ -8,7 +8,7 @@ import ( ) func TestConnectionKey(t *testing.T) { - err := cache.InitRedisClient("0.0.0.0:6379", "", "difyai123456", false, 0) + err := cache.InitRedisClient("0.0.0.0:6379", "", "difyai123456", false, 0, nil) if err != nil { t.Errorf("init redis client failed: %v", err) return diff --git a/internal/types/app/config.go b/internal/types/app/config.go index 656eea59d..26c324395 100644 --- a/internal/types/app/config.go +++ b/internal/types/app/config.go @@ -1,7 +1,11 @@ package app import ( + "crypto/tls" + "crypto/x509" "fmt" + "os" + "strings" "github.com/go-playground/validator/v10" "github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities" @@ -68,7 +72,7 @@ type Config struct { HuaweiOBSAccessKey string `envconfig:"HUAWEI_OBS_ACCESS_KEY"` HuaweiOBSSecretKey string `envconfig:"HUAWEI_OBS_SECRET_KEY"` HuaweiOBSServer string `envconfig:"HUAWEI_OBS_SERVER"` - HuaweiOBSPathStyle bool `envconfig:"HUAWEI_OBS_PATH_STYLE" default:"false"` + HuaweiOBSPathStyle bool `envconfig:"HUAWEI_OBS_PATH_STYLE" default:"false"` // volcengine tos VolcengineTOSEndpoint string `envconfig:"VOLCENGINE_TOS_ENDPOINT"` @@ -109,12 +113,14 @@ type Config struct { RoutinePoolSize int `envconfig:"ROUTINE_POOL_SIZE" validate:"required"` // redis - RedisHost string `envconfig:"REDIS_HOST"` - RedisPort uint16 `envconfig:"REDIS_PORT"` - RedisPass string `envconfig:"REDIS_PASSWORD"` - RedisUser string `envconfig:"REDIS_USERNAME"` - RedisUseSsl bool `envconfig:"REDIS_USE_SSL"` - RedisDB int `envconfig:"REDIS_DB"` + RedisHost string `envconfig:"REDIS_HOST"` + RedisPort uint16 `envconfig:"REDIS_PORT"` + RedisPass string `envconfig:"REDIS_PASSWORD"` + RedisUser string `envconfig:"REDIS_USERNAME"` + RedisDB int `envconfig:"REDIS_DB"` + RedisUseSsl bool `envconfig:"REDIS_USE_SSL"` + RedisSSLCertReqs string `envconfig:"REDIS_SSL_CERT_REQS"` + RedisSSLCACerts string `envconfig:"REDIS_SSL_CA_CERTS"` // redis sentinel RedisUseSentinel bool `envconfig:"REDIS_USE_SENTINEL"` @@ -281,6 +287,54 @@ func (c *Config) GetLocalRuntimeMaxBufferSize() int { return c.PluginRuntimeMaxBufferSize } +// RedisTLSConfig builds a *tls.Config for Redis based on envs. +func (c *Config) RedisTLSConfig() (*tls.Config, error) { + if !c.RedisUseSsl { + return nil, nil + } + + tlsConf := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + // Load custom CA certificates if provided + if strings.TrimSpace(c.RedisSSLCACerts) != "" { + pem, err := os.ReadFile(c.RedisSSLCACerts) + if err != nil { + return nil, fmt.Errorf("read REDIS_SSL_CA_CERTS: %w", err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(pem) { + return nil, fmt.Errorf("failed to append CA certs from %s", c.RedisSSLCACerts) + } + tlsConf.RootCAs = pool + } + + // Configure certificate verification based on REDIS_SSL_CERT_REQS + certReqs := strings.ToUpper(strings.TrimSpace(c.RedisSSLCertReqs)) + switch certReqs { + case "CERT_NONE": + // Skip all certificate verification (insecure) + tlsConf.InsecureSkipVerify = true + case "CERT_OPTIONAL", "CERT_REQUIRED", "": + // Require valid certificate verification (default and most secure) + // CERT_OPTIONAL is treated as CERT_REQUIRED for client-side TLS, + // as servers almost always present certificates and the client's + // choice is whether to validate them or not + tlsConf.InsecureSkipVerify = false + + // Require CA certs to be explicitly provided when CERT_REQUIRED is set + if certReqs == "CERT_REQUIRED" && strings.TrimSpace(c.RedisSSLCACerts) == "" { + return nil, fmt.Errorf("REDIS_SSL_CA_CERTS must be provided when REDIS_SSL_CERT_REQS is set to CERT_REQUIRED") + } + default: + // Invalid value - return an error instead of silently defaulting + return nil, fmt.Errorf("invalid REDIS_SSL_CERT_REQS value: %s (valid options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED)", certReqs) + } + + return tlsConf, nil +} + type PlatformType string const ( diff --git a/internal/types/app/config_test.go b/internal/types/app/config_test.go new file mode 100644 index 000000000..3efc246f0 --- /dev/null +++ b/internal/types/app/config_test.go @@ -0,0 +1,406 @@ +package app + +import ( + "crypto/tls" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRedisTLSConfig(t *testing.T) { + tests := []struct { + name string + config Config + setupCA bool + expectedTLS bool + expectedError bool + errorContains string + validateConfig func(t *testing.T, tlsConfig *tls.Config) + }{ + { + name: "SSL disabled - returns nil", + config: Config{ + RedisUseSsl: false, + }, + expectedTLS: false, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.Nil(t, tlsConfig) + }, + }, + { + name: "SSL enabled with default CERT_REQUIRED", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "", + }, + expectedTLS: true, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.NotNil(t, tlsConfig) + assert.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MinVersion) + assert.False(t, tlsConfig.InsecureSkipVerify) + assert.Nil(t, tlsConfig.RootCAs) + }, + }, + { + name: "SSL enabled with CERT_REQUIRED explicit and CA cert provided", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "CERT_REQUIRED", + }, + setupCA: true, + expectedTLS: true, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.NotNil(t, tlsConfig) + assert.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MinVersion) + assert.False(t, tlsConfig.InsecureSkipVerify) + assert.NotNil(t, tlsConfig.RootCAs) + }, + }, + { + name: "SSL enabled with CERT_REQUIRED lowercase and CA cert provided", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "cert_required", + }, + setupCA: true, + expectedTLS: true, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.NotNil(t, tlsConfig) + assert.False(t, tlsConfig.InsecureSkipVerify) + assert.NotNil(t, tlsConfig.RootCAs) + }, + }, + { + name: "SSL enabled with CERT_OPTIONAL", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "CERT_OPTIONAL", + }, + expectedTLS: true, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.NotNil(t, tlsConfig) + assert.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MinVersion) + assert.False(t, tlsConfig.InsecureSkipVerify) + }, + }, + { + name: "SSL enabled with CERT_NONE", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "CERT_NONE", + }, + expectedTLS: true, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.NotNil(t, tlsConfig) + assert.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MinVersion) + assert.True(t, tlsConfig.InsecureSkipVerify) + }, + }, + { + name: "SSL enabled with CERT_NONE lowercase", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "cert_none", + }, + expectedTLS: true, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.NotNil(t, tlsConfig) + assert.True(t, tlsConfig.InsecureSkipVerify) + }, + }, + { + name: "SSL enabled with whitespace in CERT_REQS", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: " CERT_REQUIRED ", + }, + setupCA: true, + expectedTLS: true, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.NotNil(t, tlsConfig) + assert.False(t, tlsConfig.InsecureSkipVerify) + assert.NotNil(t, tlsConfig.RootCAs) + }, + }, + { + name: "SSL enabled with invalid CERT_REQS value", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "INVALID_VALUE", + }, + expectedTLS: false, + expectedError: true, + errorContains: "invalid REDIS_SSL_CERT_REQS value", + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.Nil(t, tlsConfig) + }, + }, + { + name: "SSL enabled with custom CA certificate", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "CERT_REQUIRED", + }, + setupCA: true, + expectedTLS: true, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.NotNil(t, tlsConfig) + assert.NotNil(t, tlsConfig.RootCAs) + assert.False(t, tlsConfig.InsecureSkipVerify) + }, + }, + { + name: "SSL enabled with non-existent CA certificate file", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "CERT_REQUIRED", + RedisSSLCACerts: "/nonexistent/ca.crt", + }, + expectedTLS: false, + expectedError: true, + errorContains: "read REDIS_SSL_CA_CERTS", + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.Nil(t, tlsConfig) + }, + }, + { + name: "SSL enabled with CERT_REQUIRED but no CA certificate - should fail", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "CERT_REQUIRED", + RedisSSLCACerts: "", + }, + expectedTLS: false, + expectedError: true, + errorContains: "REDIS_SSL_CA_CERTS must be provided when REDIS_SSL_CERT_REQS is set to CERT_REQUIRED", + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.Nil(t, tlsConfig) + }, + }, + { + name: "SSL enabled with CERT_REQUIRED and whitespace-only CA certificate - should fail", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "CERT_REQUIRED", + RedisSSLCACerts: " ", + }, + expectedTLS: false, + expectedError: true, + errorContains: "REDIS_SSL_CA_CERTS must be provided when REDIS_SSL_CERT_REQS is set to CERT_REQUIRED", + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.Nil(t, tlsConfig) + }, + }, + { + name: "SSL enabled with CERT_OPTIONAL and no CA certificate - should succeed (uses system CAs)", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "CERT_OPTIONAL", + RedisSSLCACerts: "", + }, + expectedTLS: true, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.NotNil(t, tlsConfig) + assert.Nil(t, tlsConfig.RootCAs) // RootCAs is nil, will use system default + assert.False(t, tlsConfig.InsecureSkipVerify) + }, + }, + { + name: "SSL enabled with empty string (defaults to CERT_REQUIRED) and no CA certificate - should succeed (uses system CAs)", + config: Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "", + RedisSSLCACerts: "", + }, + expectedTLS: true, + expectedError: false, + validateConfig: func(t *testing.T, tlsConfig *tls.Config) { + assert.NotNil(t, tlsConfig) + assert.Nil(t, tlsConfig.RootCAs) // Empty string case allows system CAs + assert.False(t, tlsConfig.InsecureSkipVerify) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup CA certificate if needed + if tt.setupCA { + tempDir, err := os.MkdirTemp("", "redis-tls-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + caFile := filepath.Join(tempDir, "ca.crt") + err = os.WriteFile(caFile, []byte(testCACert), 0644) + require.NoError(t, err) + + tt.config.RedisSSLCACerts = caFile + } + + // Call RedisTLSConfig + tlsConfig, err := tt.config.RedisTLSConfig() + + // Check error + if tt.expectedError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + } + + // Check TLS config existence + if tt.expectedTLS { + assert.NotNil(t, tlsConfig) + } + + // Run custom validation + if tt.validateConfig != nil { + tt.validateConfig(t, tlsConfig) + } + }) + } +} + +func TestRedisTLSConfigWithInvalidCAContent(t *testing.T) { + tempDir, err := os.MkdirTemp("", "redis-tls-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + caFile := filepath.Join(tempDir, "invalid-ca.crt") + err = os.WriteFile(caFile, []byte("invalid certificate content"), 0644) + require.NoError(t, err) + + config := Config{ + RedisUseSsl: true, + RedisSSLCertReqs: "CERT_REQUIRED", + RedisSSLCACerts: caFile, + } + + tlsConfig, err := config.RedisTLSConfig() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to append CA certs") + assert.Nil(t, tlsConfig) +} + +func TestRedisTLSConfigCombinations(t *testing.T) { + // Test that MinVersion is always set to TLS 1.2 when SSL is enabled + tests := []struct { + name string + certReqs string + setupCA bool + }{ + {"CERT_NONE", "CERT_NONE", false}, + {"CERT_OPTIONAL", "CERT_OPTIONAL", false}, + {"CERT_REQUIRED", "CERT_REQUIRED", true}, // CERT_REQUIRED now requires CA cert + {"Empty (default)", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := Config{ + RedisUseSsl: true, + RedisSSLCertReqs: tt.certReqs, + } + + // Setup CA certificate if needed + if tt.setupCA { + tempDir, err := os.MkdirTemp("", "redis-tls-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + caFile := filepath.Join(tempDir, "ca.crt") + err = os.WriteFile(caFile, []byte(testCACert), 0644) + require.NoError(t, err) + + config.RedisSSLCACerts = caFile + } + + tlsConfig, err := config.RedisTLSConfig() + require.NoError(t, err) + require.NotNil(t, tlsConfig) + assert.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MinVersion, + "MinVersion should always be TLS 1.2") + }) + } +} + +func TestRedisTLSConfigCaseInsensitivity(t *testing.T) { + cases := []struct { + certReqs string + needsCA bool + }{ + {"CERT_NONE", false}, + {"cert_none", false}, + {"Cert_None", false}, + {"CERT_OPTIONAL", false}, + {"cert_optional", false}, + {"Cert_Optional", false}, + {"CERT_REQUIRED", true}, + {"cert_required", true}, + {"Cert_Required", true}, + } + + for _, tc := range cases { + t.Run(tc.certReqs, func(t *testing.T) { + config := Config{ + RedisUseSsl: true, + RedisSSLCertReqs: tc.certReqs, + } + + // Setup CA certificate if needed for CERT_REQUIRED variants + if tc.needsCA { + tempDir, err := os.MkdirTemp("", "redis-tls-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + caFile := filepath.Join(tempDir, "ca.crt") + err = os.WriteFile(caFile, []byte(testCACert), 0644) + require.NoError(t, err) + + config.RedisSSLCACerts = caFile + } + + tlsConfig, err := config.RedisTLSConfig() + assert.NoError(t, err, "Should accept case-insensitive cert requirements") + assert.NotNil(t, tlsConfig) + }) + } +} + +// Test CA certificate in PEM format (self-signed certificate for testing) +const testCACert = `-----BEGIN CERTIFICATE----- +MIIDXTCCAkWgAwIBAgIJAKL0UG+mRqqSMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTcwMjIyMDUwNzQ4WhcNMjcwMjIwMDUwNzQ4WjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAyVuOPQaVxQjBBg7CYBaKzVnfJKG6iFUvuQfJqXL9BKE8bTSASPrHGEFL +WLTh7MvYhE9nMLvPB7FzJFdI5K5pxWRmIFxM6pxEPGvGnF0Bc+cGSN2UFHQZW0Rc +qxZJu5Hbv9YMsGHG+VPx8mWfD8LDHD+M5lKVrNWfXeNXKXlPfnqP8N0vQAUv7z5Y +y0N0OJCcZqW3nFqGNAFLVqL3MzLz+2thqBKs3vG2VQ0NI0aL9T4eqN1qQXqIHWnQ +tLQgLhNCJVQxcLEu2KHyBXUJrI8FnxVAoOyPKq5wJjVPEwqBp5HWpqnNKO6eFXKE +l3BZshBqQ5W8Q6K5LQ0Hwd0qCtSxPwIDAQABo1AwTjAdBgNVHQ4EFgQU8Y1j8vPz +aR3JMXdDrLK9LeV0RjwwHwYDVR0jBBgwFoAU8Y1j8vPzaR3JMXdDrLK9LeV0Rjww +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAeYr0R8wCLxq0ySl3EQ8G +bvF/VLgLCXlvVKEiKwrSkPZTzSmfOJcQfqCAH3pJjVqDOTAZ7H0cV8CzZVpK7q3U +VPl9D5p9hF0VJ3LcMhNLXNhp5C3WBBTqCXF5FQMxgNNwdlJW0cJLrXPG8D8yXhPc +M/qXYqd7K1Q4RiXBNLxSGEqPj5mZnVKZ7JQTKCYqF5uHVx3y8c7gK3nYaXNQfZFa +N8Qs9CZmKVFvJ4KU6nOaW5X8gTCrHvBFMFaQcKKpCmWZfLnPJZMdgZXvxhx5lXXU +9nKKqk7sKB4D6LqHKQ1qRx9HJJVP5LxHMYGpxGnxXNLMaCPjLMxVJpQSZLnVfj5Y +dQ== +-----END CERTIFICATE-----` diff --git a/pkg/utils/cache/redis.go b/pkg/utils/cache/redis.go index a84351517..288c5239a 100644 --- a/pkg/utils/cache/redis.go +++ b/pkg/utils/cache/redis.go @@ -21,7 +21,7 @@ var ( ErrNotFound = errors.New("key not found") ) -func getRedisOptions(addr, username, password string, useSsl bool, db int) *redis.Options { +func getRedisOptions(addr, username, password string, useSsl bool, db int, tlsConf *tls.Config) *redis.Options { opts := &redis.Options{ Addr: addr, Username: username, @@ -29,13 +29,14 @@ func getRedisOptions(addr, username, password string, useSsl bool, db int) *redi DB: db, } if useSsl { - opts.TLSConfig = &tls.Config{} + // The provided tlsConf is guaranteed to be non-nil when useSsl is true. + opts.TLSConfig = tlsConf } return opts } -func InitRedisClient(addr, username, password string, useSsl bool, db int) error { - opts := getRedisOptions(addr, username, password, useSsl, db) +func InitRedisClient(addr, username, password string, useSsl bool, db int, tlsConf *tls.Config) error { + opts := getRedisOptions(addr, username, password, useSsl, db, tlsConf) client = redis.NewClient(opts) if _, err := client.Ping(ctx).Result(); err != nil { @@ -45,7 +46,14 @@ func InitRedisClient(addr, username, password string, useSsl bool, db int) error return nil } -func InitRedisSentinelClient(sentinels []string, masterName, username, password, sentinelUsername, sentinelPassword string, useSsl bool, db int, socketTimeout float64) error { +func InitRedisSentinelClient( + sentinels []string, + masterName, username, password, sentinelUsername, sentinelPassword string, + useSsl bool, + db int, + socketTimeout float64, + tlsConf *tls.Config, +) error { opts := &redis.FailoverOptions{ MasterName: masterName, SentinelAddrs: sentinels, @@ -57,7 +65,8 @@ func InitRedisSentinelClient(sentinels []string, masterName, username, password, } if useSsl { - opts.TLSConfig = &tls.Config{} + // The provided tlsConf is guaranteed to be non-nil when useSsl is true. + opts.TLSConfig = tlsConf } if socketTimeout > 0 { @@ -366,13 +375,15 @@ func ScanMap[V any](key string, match string, context ...redis.Cmdable) (map[str result := make(map[string]V) - ScanMapAsync[V](key, match, func(m map[string]V) error { + if err := ScanMapAsync[V](key, match, func(m map[string]V) error { for k, v := range m { result[k] = v } return nil - }) + }, context...); err != nil { + return nil, err + } return result, nil } diff --git a/pkg/utils/cache/redis_auto_type_test.go b/pkg/utils/cache/redis_auto_type_test.go index 94c766716..19a491914 100644 --- a/pkg/utils/cache/redis_auto_type_test.go +++ b/pkg/utils/cache/redis_auto_type_test.go @@ -10,7 +10,7 @@ type TestAutoTypeStruct struct { } func TestAutoType(t *testing.T) { - if err := InitRedisClient("127.0.0.1:6379", "", "difyai123456", false, 0); err != nil { + if err := InitRedisClient("127.0.0.1:6379", "", "difyai123456", false, 0, nil); err != nil { t.Fatal(err) } defer Close() @@ -35,7 +35,7 @@ func TestAutoType(t *testing.T) { } func TestAutoTypeWithGetter(t *testing.T) { - if err := InitRedisClient("127.0.0.1:6379", "", "difyai123456", false, 0); err != nil { + if err := InitRedisClient("127.0.0.1:6379", "", "difyai123456", false, 0, nil); err != nil { t.Fatal(err) } defer Close() diff --git a/pkg/utils/cache/redis_test.go b/pkg/utils/cache/redis_test.go index fae22d56b..ae3d83879 100644 --- a/pkg/utils/cache/redis_test.go +++ b/pkg/utils/cache/redis_test.go @@ -18,7 +18,7 @@ const ( ) func getRedisConnection() error { - return InitRedisClient("0.0.0.0:6379", "", "difyai123456", false, 0) + return InitRedisClient("0.0.0.0:6379", "", "difyai123456", false, 0, nil) } func TestRedisConnection(t *testing.T) { @@ -272,13 +272,13 @@ func TestRedisP2ARedis(t *testing.T) { } func TestGetRedisOptions(t *testing.T) { - opts := getRedisOptions("dummy:6379", "", "password", false, 0) + opts := getRedisOptions("dummy:6379", "", "password", false, 0, nil) if opts.TLSConfig != nil { t.Errorf("TLSConfig should not be set") return } - opts = getRedisOptions("dummy:6379", "", "password", true, 0) + opts = getRedisOptions("dummy:6379", "", "password", true, 0, nil) if opts.TLSConfig == nil { t.Errorf("TLSConfig should be set") return @@ -286,7 +286,7 @@ func TestGetRedisOptions(t *testing.T) { } func TestSetAndGet(t *testing.T) { - if err := InitRedisClient("127.0.0.1:6379", "", "difyai123456", false, 0); err != nil { + if err := InitRedisClient("127.0.0.1:6379", "", "difyai123456", false, 0, nil); err != nil { t.Fatal(err) } defer Close() @@ -318,7 +318,7 @@ func TestSetAndGet(t *testing.T) { } func TestLock(t *testing.T) { - if err := InitRedisClient("127.0.0.1:6379", "", "difyai123456", false, 0); err != nil { + if err := InitRedisClient("127.0.0.1:6379", "", "difyai123456", false, 0, nil); err != nil { t.Fatal(err) } defer Close()