diff --git a/cmd/catalogd/main.go b/cmd/catalogd/main.go index ec7c4946ff..84dbbe30b6 100644 --- a/cmd/catalogd/main.go +++ b/cmd/catalogd/main.go @@ -60,6 +60,7 @@ import ( "github.com/operator-framework/operator-controller/internal/catalogd/webhook" sharedcontrollers "github.com/operator-framework/operator-controller/internal/shared/controllers" fsutil "github.com/operator-framework/operator-controller/internal/shared/util/fs" + httputil "github.com/operator-framework/operator-controller/internal/shared/util/http" imageutil "github.com/operator-framework/operator-controller/internal/shared/util/image" "github.com/operator-framework/operator-controller/internal/shared/util/pullsecretcache" sautil "github.com/operator-framework/operator-controller/internal/shared/util/sa" @@ -291,6 +292,18 @@ func run(ctx context.Context) error { return err } + // This watches the pullCasDir and the SSL_CERT_DIR, and SSL_CERT_FILE for changes + cpwPull, err := httputil.NewCertPoolWatcher(cfg.pullCasDir, ctrl.Log.WithName("pull-ca-pool")) + if err != nil { + setupLog.Error(err, "unable to create pull-ca-pool watcher") + return err + } + cpwPull.Restart(os.Exit) + if err = mgr.Add(cpwPull); err != nil { + setupLog.Error(err, "unable to add pull-ca-pool watcher to manager") + return err + } + if cfg.systemNamespace == "" { cfg.systemNamespace = podNamespace() } diff --git a/cmd/operator-controller/main.go b/cmd/operator-controller/main.go index e52e2cb6c7..2ba4bc9f06 100644 --- a/cmd/operator-controller/main.go +++ b/cmd/operator-controller/main.go @@ -319,9 +319,26 @@ func run() error { return err } - certPoolWatcher, err := httputil.NewCertPoolWatcher(cfg.catalogdCasDir, ctrl.Log.WithName("cert-pool")) + cpwCatalogd, err := httputil.NewCertPoolWatcher(cfg.catalogdCasDir, ctrl.Log.WithName("catalogd-ca-pool")) if err != nil { - setupLog.Error(err, "unable to create CA certificate pool") + setupLog.Error(err, "unable to create catalogd-ca-pool watcher") + return err + } + cpwCatalogd.Restart(os.Exit) + if err = mgr.Add(cpwCatalogd); err != nil { + setupLog.Error(err, "unable to add catalogd-ca-pool watcher to manager") + return err + } + + // This watches the pullCasDir and the SSL_CERT_DIR, and SSL_CERT_FILE for changes + cpwPull, err := httputil.NewCertPoolWatcher(cfg.pullCasDir, ctrl.Log.WithName("pull-ca-pool")) + if err != nil { + setupLog.Error(err, "unable to create pull-ca-pool watcher") + return err + } + cpwPull.Restart(os.Exit) + if err = mgr.Add(cpwPull); err != nil { + setupLog.Error(err, "unable to add pull-ca-pool watcher to manager") return err } @@ -375,7 +392,7 @@ func run() error { } catalogClientBackend := cache.NewFilesystemCache(catalogsCachePath) catalogClient := catalogclient.New(catalogClientBackend, func() (*http.Client, error) { - return httputil.BuildHTTPClient(certPoolWatcher) + return httputil.BuildHTTPClient(cpwCatalogd) }) resolver := &resolve.CatalogResolver{ diff --git a/internal/shared/util/http/certlog.go b/internal/shared/util/http/certlog.go index 60aa0bd198..b772f93f1e 100644 --- a/internal/shared/util/http/certlog.go +++ b/internal/shared/util/http/certlog.go @@ -122,7 +122,16 @@ func logFile(filename, path, action string, log logr.Logger) { log.Error(err, "error in os.ReadFile()", "file", filepath) return } - logPem(data, filename, path, action, log) + if len(data) > 0 { + logPem(data, filename, path, action, log) + return + } + // Indicate that the file is empty + args := []any{"file", filename, "empty", "true"} + if path != "" { + args = append(args, "directory", path) + } + log.V(defaultLogLevel).Info(action, args...) } func logPem(data []byte, filename, path, action string, log logr.Logger) { diff --git a/internal/shared/util/http/certpoolwatcher.go b/internal/shared/util/http/certpoolwatcher.go index 7f95449e98..da5e78ba93 100644 --- a/internal/shared/util/http/certpoolwatcher.go +++ b/internal/shared/util/http/certpoolwatcher.go @@ -1,6 +1,7 @@ package http import ( + "context" "crypto/x509" "fmt" "os" @@ -14,13 +15,15 @@ import ( ) type CertPoolWatcher struct { - generation int - dir string - mx sync.RWMutex - pool *x509.CertPool - log logr.Logger - watcher *fsnotify.Watcher - done chan bool + generation int + dir string + sslCertPaths []string + mx sync.RWMutex + pool *x509.CertPool + log logr.Logger + watcher *fsnotify.Watcher + done chan bool + restart func(int) } // Returns the current CertPool and the generation number @@ -33,77 +36,111 @@ func (cpw *CertPoolWatcher) Get() (*x509.CertPool, int, error) { return cpw.pool.Clone(), cpw.generation, nil } -func (cpw *CertPoolWatcher) Done() { - cpw.done <- true +// Change the restart behavior +func (cpw *CertPoolWatcher) Restart(f func(int)) { + cpw.restart = f } -func NewCertPoolWatcher(caDir string, log logr.Logger) (*CertPoolWatcher, error) { - pool, err := NewCertPool(caDir, log) - if err != nil { - return nil, err +// Indicate that you're done with the CertPoolWatcher so it can terminate +// the watcher go func +func (cpw *CertPoolWatcher) Done() { + if cpw.watcher != nil { + cpw.done <- true } - watcher, err := fsnotify.NewWatcher() +} + +func (cpw *CertPoolWatcher) Start(ctx context.Context) error { + var err error + cpw.pool, err = NewCertPool(cpw.dir, cpw.log) if err != nil { - return nil, err + return err } - // If the SSL_CERT_DIR or SSL_CERT_FILE environment variables are - // specified, this means that we have some control over the system root - // location, thus they may change, thus we should watch those locations. - sslCertDir := os.Getenv("SSL_CERT_DIR") - sslCertFile := os.Getenv("SSL_CERT_FILE") - log.V(defaultLogLevel).Info("SSL environment", "SSL_CERT_DIR", sslCertDir, "SSL_CERT_FILE", sslCertFile) + watchPaths := append(cpw.sslCertPaths, cpw.dir) + watchPaths = slices.DeleteFunc(watchPaths, deleteEmptyStrings) - watchPaths := strings.Split(sslCertDir, ":") - watchPaths = append(watchPaths, caDir, sslCertFile) - watchPaths = slices.DeleteFunc(watchPaths, func(p string) bool { - if p == "" { - return true - } - if _, err := os.Stat(p); err != nil { - return true - } - return false - }) + // Nothing was configured to be watched, which means this is + // using the SystemCertPool, so we still need to no error out + if len(watchPaths) == 0 { + cpw.log.Info("No paths to watch") + return nil + } + + cpw.watcher, err = fsnotify.NewWatcher() + if err != nil { + return err + } for _, p := range watchPaths { - if err := watcher.Add(p); err != nil { - return nil, err + if err := cpw.watcher.Add(p); err != nil { + cpw.watcher.Close() + cpw.watcher = nil + return err } - logPath(p, "watching certificate", log) + logPath(p, "watching certificate", cpw.log) } - cpw := &CertPoolWatcher{ - generation: 1, - dir: caDir, - pool: pool, - log: log, - watcher: watcher, - done: make(chan bool), - } go func() { for { select { - case <-watcher.Events: + case e := <-cpw.watcher.Events: + cpw.checkForRestart(e.Name) cpw.drainEvents() - cpw.update() - case err := <-watcher.Errors: - log.Error(err, "error watching certificate dir") + cpw.update(e.Name) + case err := <-cpw.watcher.Errors: + cpw.log.Error(err, "error watching certificate dir") os.Exit(1) + case <-ctx.Done(): + cpw.Done() case <-cpw.done: - err := watcher.Close() + err := cpw.watcher.Close() if err != nil { - log.Error(err, "error closing watcher") + cpw.log.Error(err, "error closing watcher") } return } } }() + return nil +} + +func NewCertPoolWatcher(caDir string, log logr.Logger) (*CertPoolWatcher, error) { + // If the SSL_CERT_DIR or SSL_CERT_FILE environment variables are + // specified, this means that we have some control over the system root + // location, thus they may change, thus we should watch those locations. + // + // BECAUSE THE SYSTEM POOL WILL NOT UPDATE, WE HAVE TO RESTART IF THERE + // CHANGES TO EITHER OF THESE LOCATIONS: SSL_CERT_DIR, SSL_CERT_FILE + // + sslCertDir := os.Getenv("SSL_CERT_DIR") + sslCertFile := os.Getenv("SSL_CERT_FILE") + log.V(defaultLogLevel).Info("SSL environment", "SSL_CERT_DIR", sslCertDir, "SSL_CERT_FILE", sslCertFile) + + sslCertPaths := append(strings.Split(sslCertDir, ":"), sslCertFile) + sslCertPaths = slices.DeleteFunc(sslCertPaths, deleteEmptyStrings) + + cpw := &CertPoolWatcher{ + generation: 1, + dir: caDir, + sslCertPaths: sslCertPaths, + log: log, + done: make(chan bool), + } return cpw, nil } -func (cpw *CertPoolWatcher) update() { - cpw.log.Info("updating certificate pool") +func deleteEmptyStrings(p string) bool { + if p == "" { + return true + } + if _, err := os.Stat(p); err != nil { + return true + } + return false +} + +func (cpw *CertPoolWatcher) update(name string) { + cpw.log.Info("updating certificate pool", "file", name) pool, err := NewCertPool(cpw.dir, cpw.log) if err != nil { cpw.log.Error(err, "error updating certificate pool") @@ -115,6 +152,17 @@ func (cpw *CertPoolWatcher) update() { cpw.generation++ } +func (cpw *CertPoolWatcher) checkForRestart(name string) { + for _, p := range cpw.sslCertPaths { + if strings.Contains(name, p) { + cpw.log.Info("restarting due to file change", "file", name) + if cpw.restart != nil { + cpw.restart(0) + } + } + } +} + // Drain as many events as possible before doing anything // Otherwise, we will be hit with an event for _every_ entry in the // directory, and end up doing an update for each one @@ -124,7 +172,8 @@ func (cpw *CertPoolWatcher) drainEvents() { select { case <-drainTimer.C: return - case <-cpw.watcher.Events: + case e := <-cpw.watcher.Events: + cpw.checkForRestart(e.Name) } if !drainTimer.Stop() { <-drainTimer.C diff --git a/internal/shared/util/http/certpoolwatcher_test.go b/internal/shared/util/http/certpoolwatcher_test.go index ca13a478b6..f7d0a9f82d 100644 --- a/internal/shared/util/http/certpoolwatcher_test.go +++ b/internal/shared/util/http/certpoolwatcher_test.go @@ -11,6 +11,7 @@ import ( "math/big" "os" "path/filepath" + "sync/atomic" "testing" "time" @@ -61,7 +62,100 @@ func createCert(t *testing.T, name string) { // ignore the key } -func TestCertPoolWatcher(t *testing.T) { +func createTempCaDir(t *testing.T) string { + tmpCaDir, err := os.MkdirTemp("", "ca-dir") + require.NoError(t, err) + createCert(t, filepath.Join(tmpCaDir, "test1.pem")) + return tmpCaDir +} + +func TestCertPoolWatcherCaDir(t *testing.T) { + // create a temporary CA directory + tmpCaDir := createTempCaDir(t) + defer os.RemoveAll(tmpCaDir) + + os.Unsetenv("SSL_CERT_FILE") + os.Unsetenv("SSL_CERT_DIR") + + // Create the cert pool watcher + cpw, err := httputil.NewCertPoolWatcher(tmpCaDir, log.FromContext(context.Background())) + require.NoError(t, err) + require.NotNil(t, cpw) + defer cpw.Done() + restarted := &atomic.Bool{} + restarted.Store(false) + cpw.Restart(func(int) { restarted.Store(true) }) + err = cpw.Start(context.Background()) + require.NoError(t, err) + + // Get the original pool + firstPool, firstGen, err := cpw.Get() + require.NoError(t, err) + require.NotNil(t, firstPool) + + // Create a second cert in the CA directory + certName := filepath.Join(tmpCaDir, "test2.pem") + t.Logf("Create cert file at %q\n", certName) + createCert(t, certName) + + require.Eventually(t, func() bool { + secondPool, secondGen, err := cpw.Get() + if err != nil { + return false + } + // Should NOT restart, because this is not SSL_CERT_DIR nor SSL_CERT_FILE + return secondGen != firstGen && !firstPool.Equal(secondPool) && !restarted.Load() + }, 10*time.Second, time.Second) +} + +func TestCertPoolWatcherSslCertDir(t *testing.T) { + // create a temporary CA directory for SSL_CERT_DIR + tmpSslDir := createTempCaDir(t) + defer os.RemoveAll(tmpSslDir) + + // Update environment variables for the watcher - some of these should not exist + os.Unsetenv("SSL_CERT_FILE") + os.Setenv("SSL_CERT_DIR", tmpSslDir+":/tmp/does-not-exist.dir") + defer os.Unsetenv("SSL_CERT_DIR") + + // Create a different CaDir + tmpCaDir := createTempCaDir(t) + defer os.RemoveAll(tmpCaDir) + + // Create the cert pool watcher + cpw, err := httputil.NewCertPoolWatcher(tmpCaDir, log.FromContext(context.Background())) + require.NoError(t, err) + restarted := &atomic.Bool{} + restarted.Store(false) + cpw.Restart(func(int) { restarted.Store(true) }) + err = cpw.Start(context.Background()) + require.NoError(t, err) + defer cpw.Done() + + // Get the original pool + firstPool, firstGen, err := cpw.Get() + require.NoError(t, err) + require.NotNil(t, firstPool) + + // Create a second cert in SSL_CIR_DIR + certName := filepath.Join(tmpSslDir, "test2.pem") + t.Logf("Create cert file at %q\n", certName) + createCert(t, certName) + + require.Eventually(t, func() bool { + _, secondGen, err := cpw.Get() + if err != nil { + return false + } + // Because SSL_CERT_DIR is part of the SystemCertPool: + // 1. CPW only watches: it doesn't actually load it, that's the SystemCertPool's responsibility + // 2. Because the SystemCertPool never changes, we can't directly compare the pools + // 3. If SSL_CERT_DIR changes, we should expect a restart + return secondGen != firstGen && restarted.Load() + }, 10*time.Second, time.Second) +} + +func TestCertPoolWatcherSslCertFile(t *testing.T) { // create a temporary directory tmpDir, err := os.MkdirTemp("", "cert-pool") require.NoError(t, err) @@ -72,30 +166,78 @@ func TestCertPoolWatcher(t *testing.T) { t.Logf("Create cert file at %q\n", certName) createCert(t, certName) - // Update environment variables for the watcher - some of these should not exist - os.Setenv("SSL_CERT_DIR", tmpDir+":/tmp/does-not-exist.dir") - os.Setenv("SSL_CERT_FILE", "/tmp/does-not-exist.file") + // Update environment variables for the watcher + os.Unsetenv("SSL_CERT_DIR") + os.Setenv("SSL_CERT_FILE", certName) + defer os.Unsetenv("SSL_CERT_FILE") + + // Create a different CaDir + tmpCaDir := createTempCaDir(t) + defer os.RemoveAll(tmpCaDir) // Create the cert pool watcher - cpw, err := httputil.NewCertPoolWatcher(tmpDir, log.FromContext(context.Background())) + cpw, err := httputil.NewCertPoolWatcher(tmpCaDir, log.FromContext(context.Background())) require.NoError(t, err) + require.NotNil(t, cpw) defer cpw.Done() + restarted := &atomic.Bool{} + restarted.Store(false) + cpw.Restart(func(int) { restarted.Store(true) }) + err = cpw.Start(context.Background()) + require.NoError(t, err) // Get the original pool firstPool, firstGen, err := cpw.Get() require.NoError(t, err) require.NotNil(t, firstPool) - // Create a second cert - certName = filepath.Join(tmpDir, "test2.pem") + // Update the SSL_CERT_FILE t.Logf("Create cert file at %q\n", certName) createCert(t, certName) require.Eventually(t, func() bool { - secondPool, secondGen, err := cpw.Get() + _, secondGen, err := cpw.Get() if err != nil { return false } - return secondGen != firstGen && !firstPool.Equal(secondPool) - }, 30*time.Second, time.Second) + // Because SSL_CERT_FILE is part of the SystemCertPool: + // 1. CPW only watches: it doesn't actually load it, that's the SystemCertPool's responsibility + // 2. Because the SystemCertPool never changes, we can't directly compare the pools + // 3. If SSL_CERT_FILE changes, we should expect a restart + return secondGen != firstGen && restarted.Load() + }, 10*time.Second, time.Second) +} + +func TestCertPoolWatcherEmpty(t *testing.T) { + os.Unsetenv("SSL_CERT_FILE") + os.Unsetenv("SSL_CERT_DIR") + + // Create the empty cert pool watcher + cpw, err := httputil.NewCertPoolWatcher("", log.FromContext(context.Background())) + require.NoError(t, err) + require.NotNil(t, cpw) + defer cpw.Done() + err = cpw.Start(context.Background()) + require.NoError(t, err) + + pool, _, err := cpw.Get() + require.NoError(t, err) + require.NotNil(t, pool) +} + +func TestCertPoolInvalidPath(t *testing.T) { + os.Unsetenv("SSL_CERT_FILE") + os.Unsetenv("SSL_CERT_DIR") + + // Create an invalid cert pool watcher + cpw, err := httputil.NewCertPoolWatcher("/this/path/should/not/exist", log.FromContext(context.Background())) + require.NoError(t, err) + require.NotNil(t, cpw) + defer cpw.Done() + err = cpw.Start(context.Background()) + require.Error(t, err) + + pool, _, err := cpw.Get() + require.Error(t, err) + require.Nil(t, pool) } diff --git a/internal/shared/util/http/certutil.go b/internal/shared/util/http/certutil.go index fb7cdc4cbf..4dd83f0f49 100644 --- a/internal/shared/util/http/certutil.go +++ b/internal/shared/util/http/certutil.go @@ -35,7 +35,7 @@ func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) { log.V(defaultLogLevel).Info("skip directory", "name", e.Name()) continue } - log.V(defaultLogLevel).Info("load certificate", "name", e.Name(), "size", fi.Size(), "modtime", fi.ModTime()) + log.V(defaultLogLevel).Info("reading certificate file", "name", e.Name(), "size", fi.Size(), "modtime", fi.ModTime()) data, err := os.ReadFile(file) if err != nil { return nil, fmt.Errorf("error reading cert file %q: %w", file, err) @@ -44,7 +44,7 @@ func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) { if caCertPool.AppendCertsFromPEM(data) { count++ } - logPem(data, e.Name(), caDir, "loading certificate file", log) + logPem(data, e.Name(), caDir, "loading certificate", log) } // Found no certs!