diff --git a/.env.example b/.env.example index 5e077b73e..1b3fdbc94 100644 --- a/.env.example +++ b/.env.example @@ -79,6 +79,13 @@ REDIS_PORT=6379 REDIS_PASSWORD=difyai123456 REDIS_DB=0 +# redis TLS +REDIS_USE_SSL=false +## CERT_NONE | CERT_OPTIONAL | CERT_REQUIRED +#REDIS_SSL_CERT_REQS=CERT_REQUIRED +## Optional custom CA bundle (PEM) if not using public CAs +#REDIS_SSL_CA_CERTS= + # Whether to use Redis Sentinel mode. # If set to true, the application will automatically discover and connect to the master node through Sentinel. REDIS_USE_SENTINEL=false diff --git a/internal/core/plugin_manager/manager.go b/internal/core/plugin_manager/manager.go index 3d2d66ef3..30676a15c 100644 --- a/internal/core/plugin_manager/manager.go +++ b/internal/core/plugin_manager/manager.go @@ -113,6 +113,12 @@ func (p *PluginManager) GetAsset(id string) ([]byte, error) { func (p *PluginManager) Launch(configuration *app.Config) { log.Info("start plugin manager daemon...") + // Build TLS config for Redis (nil when RedisUseSsl=false) + tlsConf, err := configuration.RedisTLSConfig() + if err != nil { + log.Panic("invalid Redis TLS config: %s", err.Error()) + } + // init redis client if configuration.RedisUseSentinel { // use Redis Sentinel @@ -127,6 +133,7 @@ func (p *PluginManager) Launch(configuration *app.Config) { configuration.RedisUseSsl, configuration.RedisDB, configuration.RedisSentinelSocketTimeout, + tlsConf, // pass TLS to cache initializer ); err != nil { log.Panic("init redis sentinel client failed: %s", err.Error()) } @@ -137,6 +144,7 @@ func (p *PluginManager) Launch(configuration *app.Config) { configuration.RedisPass, configuration.RedisUseSsl, configuration.RedisDB, + tlsConf, // pass TLS to cache initializer ); err != nil { log.Panic("init redis client failed: %s", err.Error()) } diff --git a/internal/types/app/config.go b/internal/types/app/config.go index 028872a5f..e5580c673 100644 --- a/internal/types/app/config.go +++ b/internal/types/app/config.go @@ -1,7 +1,11 @@ package app import ( + "crypto/tls" + "crypto/x509" "fmt" + "os" + "strings" "github.com/go-playground/validator/v10" ) @@ -109,6 +113,10 @@ type Config struct { RedisUseSsl bool `envconfig:"REDIS_USE_SSL"` RedisDB int `envconfig:"REDIS_DB"` + // redis TLS extras + RedisSSLCertReqs string `envconfig:"REDIS_SSL_CERT_REQS" default:"CERT_REQUIRED"` + RedisSSLCACerts string `envconfig:"REDIS_SSL_CA_CERTS"` + // redis sentinel RedisUseSentinel bool `envconfig:"REDIS_USE_SENTINEL"` RedisSentinels string `envconfig:"REDIS_SENTINELS"` @@ -253,6 +261,42 @@ func (c *Config) Validate() error { return nil } +// RedisTLSConfig builds a *tls.Config for Redis based on envs. +func (c *Config) RedisTLSConfig() (*tls.Config, error) { + if !c.RedisUseSsl { + return nil, nil + } + + tlsConf := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + if strings.TrimSpace(c.RedisSSLCACerts) != "" { + pem, err := os.ReadFile(c.RedisSSLCACerts) + if err != nil { + return nil, fmt.Errorf("read REDIS_SSL_CA_CERTS: %w", err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(pem) { + return nil, fmt.Errorf("failed to append CA certs from %s", c.RedisSSLCACerts) + } + tlsConf.RootCAs = pool + } + + switch strings.ToUpper(strings.TrimSpace(c.RedisSSLCertReqs)) { + case "CERT_NONE": + tlsConf.InsecureSkipVerify = true + case "CERT_OPTIONAL": + tlsConf.InsecureSkipVerify = false + case "CERT_REQUIRED", "": + tlsConf.InsecureSkipVerify = false + default: + tlsConf.InsecureSkipVerify = false + } + + return tlsConf, nil +} + type PlatformType string const ( diff --git a/internal/utils/cache/redis.go b/internal/utils/cache/redis.go index badc279a9..f0145ad83 100644 --- a/internal/utils/cache/redis.go +++ b/internal/utils/cache/redis.go @@ -20,7 +20,7 @@ var ( ErrNotFound = errors.New("key not found") ) -func getRedisOptions(addr, username, password string, useSsl bool, db int) *redis.Options { +func getRedisOptions(addr, username, password string, useSsl bool, db int, tlsConf *tls.Config) *redis.Options { opts := &redis.Options{ Addr: addr, Username: username, @@ -28,23 +28,34 @@ func getRedisOptions(addr, username, password string, useSsl bool, db int) *redi DB: db, } if useSsl { - opts.TLSConfig = &tls.Config{} + // Use provided TLS config (encodes CERT_NONE / REQUIRED policy). If nil, default to system roots. + if tlsConf != nil { + opts.TLSConfig = tlsConf + } else { + opts.TLSConfig = &tls.Config{} + } } return opts } -func InitRedisClient(addr, username, password string, useSsl bool, db int) error { - opts := getRedisOptions(addr, username, password, useSsl, db) +func InitRedisClient(addr, username, password string, useSsl bool, db int, tlsConf *tls.Config) error { + opts := getRedisOptions(addr, username, password, useSsl, db, tlsConf) client = redis.NewClient(opts) if _, err := client.Ping(ctx).Result(); err != nil { return err } - return nil } -func InitRedisSentinelClient(sentinels []string, masterName, username, password, sentinelUsername, sentinelPassword string, useSsl bool, db int, socketTimeout float64) error { +func InitRedisSentinelClient( + sentinels []string, + masterName, username, password, sentinelUsername, sentinelPassword string, + useSsl bool, + db int, + socketTimeout float64, + tlsConf *tls.Config, +) error { opts := &redis.FailoverOptions{ MasterName: masterName, SentinelAddrs: sentinels, @@ -56,7 +67,12 @@ func InitRedisSentinelClient(sentinels []string, masterName, username, password, } if useSsl { - opts.TLSConfig = &tls.Config{} + // go-redis v9 uses TLSConfig for both Sentinel discovery and data connections + if tlsConf != nil { + opts.TLSConfig = tlsConf + } else { + opts.TLSConfig = &tls.Config{} + } } if socketTimeout > 0 { @@ -68,7 +84,6 @@ func InitRedisSentinelClient(sentinels []string, masterName, username, password, if _, err := client.Ping(ctx).Result(); err != nil { return err } - return nil } @@ -77,7 +92,6 @@ func Close() error { if client == nil { return ErrDBNotInit } - return client.Close() } @@ -85,7 +99,6 @@ func getCmdable(context ...redis.Cmdable) redis.Cmdable { if len(context) > 0 { return context[0] } - return client } @@ -365,13 +378,12 @@ func ScanMap[V any](key string, match string, context ...redis.Cmdable) (map[str result := make(map[string]V) - ScanMapAsync[V](key, match, func(m map[string]V) error { + _ = ScanMapAsync[V](key, match, func(m map[string]V) error { for k, v := range m { result[k] = v } - return nil - }) + }, context...) return result, nil } @@ -388,7 +400,6 @@ func ScanMapAsync[V any](key string, match string, fn func(map[string]V) error, kvs, newCursor, err := getCmdable(context...). HScan(ctx, serialKey(key), cursor, match, 32). Result() - if err != nil { return err } @@ -399,7 +410,6 @@ func ScanMapAsync[V any](key string, match string, fn func(map[string]V) error, if err != nil { continue } - result[kvs[i]] = value } @@ -410,7 +420,6 @@ func ScanMapAsync[V any](key string, match string, fn func(map[string]V) error, if newCursor == 0 { break } - cursor = newCursor } @@ -533,7 +542,6 @@ func Subscribe[T any](channel string) (<-chan T, func()) { if err != nil { continue } - ch <- v case *redis.Pong: default: