diff --git a/pkg/acquisition/acquisition_test.go b/pkg/acquisition/acquisition_test.go index 56685f345ef..471d21672d3 100644 --- a/pkg/acquisition/acquisition_test.go +++ b/pkg/acquisition/acquisition_test.go @@ -76,15 +76,19 @@ func (*MockSourceCantRun) CanRun() error { return errors.New("can't run bro") func (*MockSourceCantRun) GetName() string { return "mock_cant_run" } // appendMockSource is only used to add mock source for tests. -func appendMockSource() { - registry.RegisterTestFactory("mock", func() types.DataSource { return &MockSource{} }) - registry.RegisterTestFactory("mock_cant_run", func() types.DataSource { return &MockSourceCantRun{} }) +func appendMockSource(t *testing.T) { + t.Helper() + + restore := registry.RegisterTestFactory("mock", func() types.DataSource { return &MockSource{} }) + t.Cleanup(restore) + restore = registry.RegisterTestFactory("mock_cant_run", func() types.DataSource { return &MockSourceCantRun{} }) + t.Cleanup(restore) } func TestDataSourceConfigure(t *testing.T) { ctx := t.Context() - appendMockSource() + appendMockSource(t) tests := []struct { TestName string @@ -218,7 +222,7 @@ filename: foo.log } func TestLoadAcquisitionFromFiles(t *testing.T) { - appendMockSource() + appendMockSource(t) t.Setenv("TEST_ENV", "test_value2") ctx := t.Context() @@ -554,7 +558,8 @@ func TestConfigureByDSN(t *testing.T) { }, } - registry.RegisterTestFactory("mockdsn", func() types.DataSource { return &MockSourceByDSN{} }) + restore := registry.RegisterTestFactory("mockdsn", func() types.DataSource { return &MockSourceByDSN{} }) + t.Cleanup(restore) for _, tc := range tests { t.Run(tc.dsn, func(t *testing.T) { diff --git a/pkg/acquisition/registry/registry.go b/pkg/acquisition/registry/registry.go index a864de6a909..8aff1a097c6 100644 --- a/pkg/acquisition/registry/registry.go +++ b/pkg/acquisition/registry/registry.go @@ -3,6 +3,7 @@ package registry import ( "errors" "fmt" + "sync" "github.com/crowdsecurity/crowdsec/pkg/acquisition/types" "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" @@ -10,46 +11,74 @@ import ( // factoriesByName is filled at init time so the application can report // if the datasources are unsupported, or simply excluded from the build. -// We don't need to guard with a mutex if all writes to the map are done -// inside init() functions, as they are guaranteed to run sequentially. -var factoriesByName = map[string]func() types.DataSource{} +var ( + factoriesByName = map[string]types.DataSourceFactory{} + mu sync.RWMutex +) + +func register(module string, factory types.DataSourceFactory) (restore func()) { + if module == "" { + panic("registry: datasource type is empty") + } + + if factory == nil { + panic("registry: factory is nil for " + module) + } + + mu.Lock() + prev, had := factoriesByName[module] + factoriesByName[module] = factory + mu.Unlock() + + return func() { + mu.Lock() + if had { + factoriesByName[module] = prev + } else { + delete(factoriesByName, module) + } + mu.Unlock() + } +} // RegisterFactory registers a datasource constructor in the factoriesByName map. // It must be called in the init() function of the datasource package. // In addition, the build component is registered so it will be reported // by the "cscli version / crowdsec --version" commands. -func RegisterFactory(moduleName string, factory types.DataSourceFactory) { - component.Register("datasource_" + moduleName) - factoriesByName[moduleName] = factory +func RegisterFactory(module string, factory types.DataSourceFactory) { + component.Register("datasource_" + module) + register(module, factory) } // RegisterTestFactory does not attempt to register it as a component, // production code should call RegisterFactory() instead and make the datasource // code optional using the appropriate build tag. // This function may be called outside init(). -func RegisterTestFactory(moduleName string, factory types.DataSourceFactory) { - factoriesByName[moduleName] = factory +func RegisterTestFactory(module string, factory types.DataSourceFactory) (restore func()) { + return register(module, factory) } -func LookupFactory(moduleName string) (types.DataSourceFactory, error) { - source, registered := factoriesByName[moduleName] - if registered { - return source, nil +func LookupFactory(module string) (types.DataSourceFactory, error) { + if module == "" { + return nil, errors.New("data source type is empty") } - built, known := component.Built["datasource_"+moduleName] + mu.RLock() + factory, registered := factoriesByName[module] + mu.RUnlock() - if moduleName == "" { - return nil, errors.New("data source type is empty") + if registered { + return factory, nil } + built, known := component.Built["datasource_"+module] if !known { - return nil, fmt.Errorf("unknown data source %s", moduleName) + return nil, fmt.Errorf("unknown data source %s", module) } if built { - panic("datasource " + moduleName + " is built but not registered") + panic("datasource " + module + " is built but not registered") } - return nil, fmt.Errorf("data source %s is not built in this version of crowdsec", moduleName) + return nil, fmt.Errorf("data source %s is not built in this version of crowdsec", module) }