Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions pkg/api/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type Controller struct {
HTPasswd *HTPasswd
HTPasswdWatcher *HTPasswdWatcher
LDAPClient *LDAPClient
CertReloader *CertReloader
taskScheduler *scheduler.Scheduler
Healthz *common.Healthz
// runtime params
Expand Down Expand Up @@ -204,6 +205,18 @@ func (c *Controller) Run() error {

tlsConfig := c.Config.CopyTLSConfig()
if tlsConfig != nil && tlsConfig.Key != "" && tlsConfig.Cert != "" {
// Create certificate reloader for automatic TLS certificate updates
certReloader, err := NewCertReloader(tlsConfig.Cert, tlsConfig.Key, c.Log)
if err != nil {
c.Log.Error().Err(err).Str("cert", tlsConfig.Cert).Str("key", tlsConfig.Key).
Msg("failed to load TLS certificates")

return err
}

// Store the CertReloader so it can be closed during shutdown
c.CertReloader = certReloader

// These are the same as the cipher suites in defaultCipherSuitesFIPS for TLS 1.2
// see https://cs.opensource.google/go/go/+/refs/tags/go1.24.9:src/crypto/tls/defaults.go;l=123
// Note: Order doesn't matter - Go 1.17+ automatically orders cipher suites based on
Expand Down Expand Up @@ -239,7 +252,8 @@ func (c *Controller) Run() error {
CipherSuites: cipherSuites,
CurvePreferences: curvePreferences,
// PreferServerCipherSuites is ignored in Go 1.17+ - Go automatically orders cipher suites
MinVersion: tls.VersionTLS12,
MinVersion: tls.VersionTLS12,
GetCertificate: certReloader.GetCertificateFunc(),
}

if tlsConfig.CACert != "" {
Expand All @@ -266,7 +280,8 @@ func (c *Controller) Run() error {

c.Healthz.Ready()

return server.ServeTLS(listener, tlsConfig.Cert, tlsConfig.Key)
// Pass empty strings to ServeTLS - certificates will be loaded via GetCertificate callback
return server.ServeTLS(listener, "", "")
}

c.Healthz.Ready()
Expand Down Expand Up @@ -484,6 +499,11 @@ func (c *Controller) StopBackgroundTasks() {
if c.HTPasswdWatcher != nil {
_ = c.HTPasswdWatcher.Close()
}

// Close CertReloader to prevent resource leaks
if c.CertReloader != nil {
_ = c.CertReloader.Close()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is closed in StopBackgroundTasks() doesn't it need to start in StartBackgroundTasks() for simetry?

There are tests which call Run() and Shutdown() multiple times on the same object, I think I fixed the background tasks issues when I refactored the HTPasswdWatcher, but let's not risk it.

}
}

func (c *Controller) StartBackgroundTasks() {
Expand Down
280 changes: 280 additions & 0 deletions pkg/api/tlscert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
package api

import (
"crypto/tls"
"os"
"path/filepath"
"sync"
"time"

"github.com/fsnotify/fsnotify"

"zotregistry.dev/zot/v2/pkg/log"
)

const (
// certCheckCacheDuration is the minimum time between file stat checks when fsnotify is unavailable.
// This prevents excessive file system calls during high TLS handshake rates.
certCheckCacheDuration = 1 * time.Second
)

// CertReloader handles automatic reloading of TLS certificates without downtime.
// It monitors certificate and key files for changes and reloads them dynamically
// using a GetCertificate callback in tls.Config.
type CertReloader struct {
certMu sync.RWMutex
cert *tls.Certificate
certPath string
keyPath string
certMod time.Time
keyMod time.Time
log log.Logger
watcher *fsnotify.Watcher
reloadMu sync.Mutex // Prevents concurrent reload operations
lastCheck time.Time
checkCache time.Duration // Minimum time between file stat checks
stopWatcher chan struct{}
closeOnce sync.Once // Ensures Close() can be called multiple times safely
}

// NewCertReloader creates a new certificate reloader and loads the initial certificate.
// It starts an fsnotify watcher to monitor certificate file changes.
func NewCertReloader(certPath, keyPath string, logger log.Logger) (*CertReloader, error) {
reloader := &CertReloader{
certPath: certPath,
keyPath: keyPath,
log: logger,
checkCache: certCheckCacheDuration,
stopWatcher: make(chan struct{}),
}

if err := reloader.reload(); err != nil {
return nil, err
}

// Start fsnotify watcher in background
if err := reloader.startWatcher(); err != nil {
// Log warning but don't fail - we'll fall back to periodic checking
logger.Warn().Err(err).Msg("failed to start fsnotify watcher, falling back to periodic checking")
}

// NOTE: Do not add initialization that can fail after this point without ensuring
// the watcher is stopped (e.g., by calling Close on error), otherwise the
// watchLoop goroutine started by startWatcher could be leaked.

return reloader, nil
}

// Close stops the file watcher and releases resources.
// This method is safe to call multiple times.
func (cr *CertReloader) Close() error {
var err error
cr.closeOnce.Do(func() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work? If I remember correctly, we have tests calling Run() multiple times on the same controller. That would mean starting the watcher multiple times. This code closes it just once.

if cr.stopWatcher != nil {
close(cr.stopWatcher)
}

if cr.watcher != nil {
err = cr.watcher.Close()
}
})

return err
}

// startWatcher initializes the fsnotify watcher for certificate files.
func (cr *CertReloader) startWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}

cr.watcher = watcher

// Watch the directory containing the certificate files
// This is more reliable than watching files directly, especially for atomic file updates
certDir := filepath.Dir(cr.certPath)
keyDir := filepath.Dir(cr.keyPath)

if err := watcher.Add(certDir); err != nil {
watcher.Close()

return err
}

// If cert and key are in different directories, watch both
if certDir != keyDir {
if err := watcher.Add(keyDir); err != nil {
watcher.Close()

return err
}
}

// Start goroutine to handle file system events
go cr.watchLoop()

return nil
}

// watchLoop handles file system events from fsnotify.
func (cr *CertReloader) watchLoop() {
for {
select {
case <-cr.stopWatcher:
return
case event, ok := <-cr.watcher.Events:
if !ok {
return
}

// Check if the event is for our certificate or key files
if event.Name == cr.certPath || event.Name == cr.keyPath {
// Only process write and create events
if event.Op&(fsnotify.Write|fsnotify.Create) != 0 {
cr.log.Debug().Str("file", event.Name).Str("op", event.Op.String()).
Msg("certificate file change detected")

// Try to reload the certificate
cr.tryReload()
}
}
case err, ok := <-cr.watcher.Errors:
if !ok {
return
}
cr.log.Warn().Err(err).Msg("fsnotify watcher error")
}
}
}

// tryReload attempts to reload certificates with proper concurrency control.
func (cr *CertReloader) tryReload() {
// Use mutex to ensure only one reload happens at a time
// This prevents race condition where multiple goroutines detect changes simultaneously
cr.reloadMu.Lock()
defer cr.reloadMu.Unlock()

if err := cr.reload(); err != nil {
cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("failed to reload TLS certificates")
} else {
cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("TLS certificates reloaded successfully")
}
}

// reload loads the certificate and key from disk and updates the internal certificate.
func (cr *CertReloader) reload() error {
// Get file modification times
certInfo, err := os.Stat(cr.certPath)
if err != nil {
return err
}

keyInfo, err := os.Stat(cr.keyPath)
if err != nil {
return err
}

certMod := certInfo.ModTime()
keyMod := keyInfo.ModTime()

// Load the certificate
newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
if err != nil {
return err
}

// Update the certificate and modification times
cr.certMu.Lock()
defer cr.certMu.Unlock()

cr.cert = &newCert
cr.certMod = certMod
cr.keyMod = keyMod

return nil
}

// maybeReload checks if the certificate files have been modified and reloads them if necessary.
// This is used as a fallback when fsnotify is not available or fails.
// Uses time-based caching to avoid excessive file system calls.
func (cr *CertReloader) maybeReload() error {
// Use write lock for both check and update to prevent race conditions
// While less efficient than RLock+Lock upgrade, this ensures only one goroutine
// updates lastCheck at a time, preventing multiple goroutines from bypassing
// the cache check simultaneously. Since we have a 1-second cache, this lock
// is acquired at most once per second, making the performance impact acceptable.
cr.certMu.Lock()
if time.Since(cr.lastCheck) < cr.checkCache {
// Recently checked, skip stat calls
cr.certMu.Unlock()

return nil
}
// Update last check time within the same critical section as the cache check
cr.lastCheck = time.Now()
cr.certMu.Unlock()

// Check cert file modification time
certInfo, err := os.Stat(cr.certPath)
if err != nil {
return err
}

keyInfo, err := os.Stat(cr.keyPath)
if err != nil {
return err
}

certMod := certInfo.ModTime()
keyMod := keyInfo.ModTime()

// Check if files have been modified
cr.certMu.RLock()
needsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
cr.certMu.RUnlock()

if needsReload {
// Use reloadMu to prevent concurrent reload operations
cr.reloadMu.Lock()
defer cr.reloadMu.Unlock()

// Double-check after acquiring lock - another goroutine might have already reloaded
cr.certMu.RLock()
stillNeedsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod)
cr.certMu.RUnlock()

if stillNeedsReload {
if err := cr.reload(); err != nil {
cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("failed to reload TLS certificates")

return err
}

cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath).
Msg("TLS certificates reloaded successfully")
}
}

return nil
}

// GetCertificateFunc returns a function that can be used as tls.Config.GetCertificate.
// This function checks for certificate updates on each TLS handshake and reloads if necessary.
// If fsnotify watcher is active, this only performs time-cached checks as a fallback.
func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
// Try to reload the certificate if it has changed
// This is a fallback mechanism when fsnotify is not available
// Errors are logged but ignored to maintain availability with existing certificate
_ = cr.maybeReload()

cr.certMu.RLock()
defer cr.certMu.RUnlock()

return cr.cert, nil
}
}
Loading
Loading