Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions pkg/acquisition/acquisition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,19 @@ func (*MockSourceCantRun) CanRun() error { return errors.New("can't run bro")
func (*MockSourceCantRun) GetName() string { return "mock_cant_run" }

// appendMockSource is only used to add mock source for tests.
func appendMockSource() {
registry.RegisterTestFactory("mock", func() types.DataSource { return &MockSource{} })
registry.RegisterTestFactory("mock_cant_run", func() types.DataSource { return &MockSourceCantRun{} })
func appendMockSource(t *testing.T) {
t.Helper()

restore := registry.RegisterTestFactory("mock", func() types.DataSource { return &MockSource{} })
t.Cleanup(restore)
restore = registry.RegisterTestFactory("mock_cant_run", func() types.DataSource { return &MockSourceCantRun{} })
t.Cleanup(restore)
}

func TestDataSourceConfigure(t *testing.T) {
ctx := t.Context()

appendMockSource()
appendMockSource(t)

tests := []struct {
TestName string
Expand Down Expand Up @@ -218,7 +222,7 @@ filename: foo.log
}

func TestLoadAcquisitionFromFiles(t *testing.T) {
appendMockSource()
appendMockSource(t)
t.Setenv("TEST_ENV", "test_value2")

ctx := t.Context()
Expand Down Expand Up @@ -554,7 +558,8 @@ func TestConfigureByDSN(t *testing.T) {
},
}

registry.RegisterTestFactory("mockdsn", func() types.DataSource { return &MockSourceByDSN{} })
restore := registry.RegisterTestFactory("mockdsn", func() types.DataSource { return &MockSourceByDSN{} })
t.Cleanup(restore)

for _, tc := range tests {
t.Run(tc.dsn, func(t *testing.T) {
Expand Down
65 changes: 47 additions & 18 deletions pkg/acquisition/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,82 @@ package registry
import (
"errors"
"fmt"
"sync"

"github.com/crowdsecurity/crowdsec/pkg/acquisition/types"
"github.com/crowdsecurity/crowdsec/pkg/cwversion/component"
)

// factoriesByName is filled at init time so the application can report
// if the datasources are unsupported, or simply excluded from the build.
// We don't need to guard with a mutex if all writes to the map are done
// inside init() functions, as they are guaranteed to run sequentially.
var factoriesByName = map[string]func() types.DataSource{}
var (
factoriesByName = map[string]types.DataSourceFactory{}
mu sync.RWMutex
)

func register(module string, factory types.DataSourceFactory) (restore func()) {
if module == "" {
panic("registry: datasource type is empty")
}

if factory == nil {
panic("registry: factory is nil for " + module)
}

mu.Lock()
prev, had := factoriesByName[module]
factoriesByName[module] = factory
mu.Unlock()

return func() {
mu.Lock()
if had {
factoriesByName[module] = prev
} else {
delete(factoriesByName, module)
}
mu.Unlock()
}
}

// RegisterFactory registers a datasource constructor in the factoriesByName map.
// It must be called in the init() function of the datasource package.
// In addition, the build component is registered so it will be reported
// by the "cscli version / crowdsec --version" commands.
func RegisterFactory(moduleName string, factory types.DataSourceFactory) {
component.Register("datasource_" + moduleName)
factoriesByName[moduleName] = factory
func RegisterFactory(module string, factory types.DataSourceFactory) {
component.Register("datasource_" + module)
register(module, factory)
}

// RegisterTestFactory does not attempt to register it as a component,
// production code should call RegisterFactory() instead and make the datasource
// code optional using the appropriate build tag.
// This function may be called outside init().
func RegisterTestFactory(moduleName string, factory types.DataSourceFactory) {
factoriesByName[moduleName] = factory
func RegisterTestFactory(module string, factory types.DataSourceFactory) (restore func()) {
return register(module, factory)
}

func LookupFactory(moduleName string) (types.DataSourceFactory, error) {
source, registered := factoriesByName[moduleName]
if registered {
return source, nil
func LookupFactory(module string) (types.DataSourceFactory, error) {
if module == "" {
return nil, errors.New("data source type is empty")
}

built, known := component.Built["datasource_"+moduleName]
mu.RLock()
factory, registered := factoriesByName[module]
mu.RUnlock()

if moduleName == "" {
return nil, errors.New("data source type is empty")
if registered {
return factory, nil
}

built, known := component.Built["datasource_"+module]
if !known {
return nil, fmt.Errorf("unknown data source %s", moduleName)
return nil, fmt.Errorf("unknown data source %s", module)
}

if built {
panic("datasource " + moduleName + " is built but not registered")
panic("datasource " + module + " is built but not registered")
}

return nil, fmt.Errorf("data source %s is not built in this version of crowdsec", moduleName)
return nil, fmt.Errorf("data source %s is not built in this version of crowdsec", module)
}